This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git
The following commit(s) were added to refs/heads/master by this push:
     new 9e254ba  [SPARK-48754] Adding tests, structure, best practices to productionize code base
9e254ba is described below

commit 9e254ba7583a06bfee2acc4766fd4dde00198f3c
Author: Mathias Schwarz <165780420+mathiasschw...@users.noreply.github.com>
AuthorDate: Tue Jul 2 17:28:39 2024 +0900

    [SPARK-48754] Adding tests, structure, best practices to productionize code base

    ### What changes were proposed in this pull request?
    This change contains several improvements that all aim to increase the code quality of the spark-connect-go repo. All in all, these changes push the repo much closer to best-practice Go without major semantic changes. The changes fall into these categories:
    - Improved unit test coverage by about 30 percentage points
    - Decoupled the components in the sql package so that they are individually testable and depend only on each other's interfaces rather than implementations
    - Added context propagation to the code base. This allows users of the library to set connection timeouts, auth headers, etc.
    - Added method/function-level comments where they were missing for public functions
    - Removed the global var builder 'entry point' and replaced it with a normal constructor, so that each builder is simply a new instance instead of relying on the previous copy semantics
    - Added a simple error hierarchy so that errors can be handled by checking error types instead of matching string values
    - Created constructors with required params for all structs instead of having users create internal structs directly
    - Removed a strange case of panicking the whole process when input was invalid
    - Updated documentation and examples to reflect these changes

    ### Why are the changes needed?
    These changes aim (along with subsequent changes) to get this code base to a point where it will eventually be fit for production use, something that is strictly forbidden right now.

    ### Does this PR introduce _any_ user-facing change?
    The PR aims to avoid API changes as much as possible, but in a few cases this was not possible. In particular, functions that eventually result in an outbound gRPC call now take a context parameter. This is required for production-grade code. In addition, the builder is instantiated slightly differently (actually instantiated instead of being a global var), but its API otherwise remains the same.

    ### How was this patch tested?
    All the code that was touched has gotten some degree of unit testing that at least ensures coverage as well as checks the output.

    Closes #20 from mathiasschw-db/mathiasschw-db/productionize.
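For readers skimming the diff below, here is a minimal sketch of how the reworked API fits together. It is assembled from the examples and tests in this change rather than copied verbatim from the repo; the connection string, timeout, and query are placeholders.

```go
package main

import (
	"context"
	"errors"
	"log"
	"time"

	"github.com/apache/spark-connect-go/v1/client/sparkerrors"
	"github.com/apache/spark-connect-go/v1/client/sql"
)

func main() {
	// Contexts now flow into every call that ends up on the wire, so callers
	// can bound connection setup and query execution themselves.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// The global SparkSession.Builder entry point is gone; the builder is a
	// plain value returned by a constructor.
	spark, err := sql.NewSessionBuilder().Remote("sc://localhost:15002").Build(ctx)
	if err != nil {
		// Errors carry a type that can be matched with errors.Is instead of
		// comparing strings.
		if errors.Is(err, sparkerrors.InvalidInputError) {
			log.Fatalf("bad connection string: %s", err)
		}
		log.Fatalf("Failed: %s", err)
	}
	defer spark.Stop()

	// Functions that trigger an outbound gRPC call now take the context.
	df, err := spark.Sql(ctx, "select 'apple' as word, 123 as count")
	if err != nil {
		log.Fatalf("Failed: %s", err)
	}
	if err := df.Show(ctx, 100, false); err != nil {
		log.Fatalf("Failed: %s", err)
	}
}
```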
Authored-by: Mathias Schwarz <165780420+mathiasschw...@users.noreply.github.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .github/workflows/build.yml | 4 +- client/channel/channel.go | 24 ++-- client/channel/channel_test.go | 12 +- client/channel/compat.go | 6 + client/{sql/row.go => sparkerrors/errors.go} | 36 ++++-- client/sparkerrors/errors_test.go | 17 +++ client/sql/dataframe.go | 134 +++++++++++++--------- client/sql/dataframe_test.go | 78 ++++++++++++- client/sql/dataframereader.go | 34 +++--- client/sql/dataframereader_test.go | 26 +++++ client/sql/dataframewriter.go | 35 ++++-- client/sql/dataframewriter_test.go | 55 ++++++++- client/sql/datatype.go | 17 +-- client/sql/executeplanclient.go | 37 ++++++ client/sql/mocks_test.go | 112 ++++++++++++++++++ client/sql/plan_test.go | 13 +++ client/sql/row.go | 17 ++- client/sql/row_test.go | 25 ++++ client/sql/sparksession.go | 82 +++++-------- client/sql/sparksession_test.go | 111 ++++++++++++++++++ client/sql/structtype.go | 2 + cmd/spark-connect-example-raw-grpc-client/main.go | 12 +- cmd/spark-connect-example-spark-session/main.go | 82 +++++++------ go.mod | 2 +- quick-start.md | 57 ++++++--- 25 files changed, 790 insertions(+), 240 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9d0c13b..b6b1c4c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -45,8 +45,8 @@ jobs: ref: master - name: Install Golang run: | - curl -LO https://go.dev/dl/go1.19.9.linux-amd64.tar.gz - sudo tar -C /usr/local -xzf go1.19.9.linux-amd64.tar.gz + curl -LO https://go.dev/dl/go1.21.11.linux-amd64.tar.gz + sudo tar -C /usr/local -xzf go1.21.11.linux-amd64.tar.gz - name: Install Buf run: | # See more in "Installation" https://docs.buf.build/installation#tarball diff --git a/client/channel/channel.go b/client/channel/channel.go index 6cf7696..b5d459d 100644 --- a/client/channel/channel.go +++ b/client/channel/channel.go @@ -17,6 +17,7 @@ package channel import ( + "context" "crypto/tls" "crypto/x509" "errors" @@ -26,6 +27,7 @@ import ( "strconv" "strings" + "github.com/apache/spark-connect-go/v1/client/sparkerrors" "golang.org/x/oauth2" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -35,11 +37,11 @@ import ( // Reserved header parameters that must not be injected as variables. var reservedParams = []string{"user_id", "token", "use_ssl"} -// The ChannelBuilder is used to parse the different parameters of the connection +// Builder is used to parse the different parameters of the connection // string according to the specification documented here: // // https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md -type ChannelBuilder struct { +type Builder struct { Host string Port int Token string @@ -47,10 +49,10 @@ type ChannelBuilder struct { Headers map[string]string } -// Finalizes the creation of the gprc.ClientConn by creating a GRPC channel +// Build finalizes the creation of the gprc.ClientConn by creating a GRPC channel // with the necessary options extracted from the connection string. For // TLS connections, this function will load the system certificates. 
-func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) { +func (cb *Builder) Build(ctx context.Context) (*grpc.ClientConn, error) { var opts []grpc.DialOption opts = append(opts, grpc.WithAuthority(cb.Host)) @@ -76,16 +78,16 @@ func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) { } remote := fmt.Sprintf("%v:%v", cb.Host, cb.Port) - conn, err := grpc.Dial(remote, opts...) + conn, err := grpc.DialContext(ctx, remote, opts...) if err != nil { - return nil, fmt.Errorf("failed to connect to remote %s: %w", remote, err) + return nil, sparkerrors.WithType(fmt.Errorf("failed to connect to remote %s: %w", remote, err), sparkerrors.ConnectionError) } return conn, nil } -// Creates a new instance of the ChannelBuilder. This constructor effectively +// NewBuilder creates a new instance of the Builder. This constructor effectively // parses the connection string and extracts the relevant parameters directly. -func NewBuilder(connection string) (*ChannelBuilder, error) { +func NewBuilder(connection string) (*Builder, error) { u, err := url.Parse(connection) if err != nil { @@ -93,7 +95,7 @@ func NewBuilder(connection string) (*ChannelBuilder, error) { } if u.Scheme != "sc" { - return nil, errors.New("URL schema must be set to `sc`.") + return nil, sparkerrors.WithType(errors.New("URL schema must be set to `sc`"), sparkerrors.InvalidInputError) } var port = 15002 @@ -115,10 +117,10 @@ func NewBuilder(connection string) (*ChannelBuilder, error) { // Validate that the URL path is empty or follows the right format. if u.Path != "" && !strings.HasPrefix(u.Path, "/;") { - return nil, fmt.Errorf("The URL path (%v) must be empty or have a proper parameter syntax.", u.Path) + return nil, sparkerrors.WithType(fmt.Errorf("the URL path (%v) must be empty or have a proper parameter syntax", u.Path), sparkerrors.InvalidInputError) } - cb := &ChannelBuilder{ + cb := &Builder{ Host: host, Port: port, Headers: map[string]string{}, diff --git a/client/channel/channel_test.go b/client/channel/channel_test.go index f4c4bc3..1395d99 100644 --- a/client/channel/channel_test.go +++ b/client/channel/channel_test.go @@ -17,10 +17,12 @@ package channel_test import ( + "context" "strings" "testing" "github.com/apache/spark-connect-go/v1/client/channel" + "github.com/apache/spark-connect-go/v1/client/sparkerrors" "github.com/stretchr/testify/assert" ) @@ -49,7 +51,8 @@ func TestBasicChannelParsing(t *testing.T) { assert.Nilf(t, err, "Port must be a valid number %v", err) _, err = channel.NewBuilder("sc://abcd/this") - assert.True(t, strings.Contains(err.Error(), "The URL path"), "URL path elements are not allowed") + assert.True(t, strings.Contains(err.Error(), "URL path"), "URL path elements are not allowed") + assert.ErrorIs(t, err, sparkerrors.InvalidInputError) cb, err = channel.NewBuilder(goodChannelURL) assert.Equal(t, "host", cb.Host) @@ -60,7 +63,7 @@ func TestBasicChannelParsing(t *testing.T) { assert.Equal(t, "b", cb.Token) cb, err = channel.NewBuilder("sc://localhost:443/;token=token;user_id=user_id;cluster_id=a") - assert.Nilf(t, err, "Unexpected error: %v", err) + assert.NoError(t, err) assert.Equal(t, 443, cb.Port) assert.Equal(t, "localhost", cb.Host) assert.Equal(t, "token", cb.Token) @@ -68,15 +71,16 @@ func TestBasicChannelParsing(t *testing.T) { } func TestChannelBuildConnect(t *testing.T) { + ctx := context.Background() cb, err := channel.NewBuilder("sc://localhost") assert.Nil(t, err, "Should not have an error for a proper URL.") - conn, err := cb.Build() + conn, err := cb.Build(ctx) 
assert.Nil(t, err, "no error for proper connection") assert.NotNil(t, conn) cb, err = channel.NewBuilder("sc://localhost:443/;token=abcd;user_id=a") assert.Nil(t, err, "Should not have an error for a proper URL.") - conn, err = cb.Build() + conn, err = cb.Build(ctx) assert.Nil(t, err, "no error for proper connection") assert.NotNil(t, conn) } diff --git a/client/channel/compat.go b/client/channel/compat.go new file mode 100644 index 0000000..33f8a3e --- /dev/null +++ b/client/channel/compat.go @@ -0,0 +1,6 @@ +package channel + +// ChannelBuilder re-exports Builder as its previous name for compatibility. +// +// Deprecated: use Builder instead. +type ChannelBuilder = Builder diff --git a/client/sql/row.go b/client/sparkerrors/errors.go similarity index 52% copy from client/sql/row.go copy to client/sparkerrors/errors.go index 3bee2ac..770cdef 100644 --- a/client/sql/row.go +++ b/client/sparkerrors/errors.go @@ -14,22 +14,36 @@ // See the License for the specific language governing permissions and // limitations under the License. -package sql +package sparkerrors -type Row interface { - Schema() (*StructType, error) - Values() ([]any, error) +import ( + "errors" + "fmt" +) + +type wrappedError struct { + errorType error + cause error } -type GenericRowWithSchema struct { - values []any - schema *StructType +func (w *wrappedError) Unwrap() []error { + return []error{w.errorType, w.cause} } -func (r *GenericRowWithSchema) Schema() (*StructType, error) { - return r.schema, nil +func (w *wrappedError) Error() string { + return fmt.Sprintf("%s: %s", w.errorType, w.cause) } -func (r *GenericRowWithSchema) Values() ([]any, error) { - return r.values, nil +// WithType wraps an error with a type that can later be checked using `errors.Is` +func WithType(err error, errType errorType) error { + return &wrappedError{cause: err, errorType: errType} } + +type errorType error + +var ( + ConnectionError = errorType(errors.New("connection error")) + ReadError = errorType(errors.New("read error")) + ExecutionError = errorType(errors.New("execution error")) + InvalidInputError = errorType(errors.New("invalid input")) +) diff --git a/client/sparkerrors/errors_test.go b/client/sparkerrors/errors_test.go new file mode 100644 index 0000000..f5857ec --- /dev/null +++ b/client/sparkerrors/errors_test.go @@ -0,0 +1,17 @@ +package sparkerrors + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWithTypeGivesAndErrorThatIsOfThatType(t *testing.T) { + err := WithType(assert.AnError, ConnectionError) + assert.ErrorIs(t, err, ConnectionError) +} + +func TestErrorStringContainsErrorType(t *testing.T) { + err := WithType(assert.AnError, ConnectionError) + assert.Contains(t, err.Error(), ConnectionError.Error()) +} diff --git a/client/sql/dataframe.go b/client/sql/dataframe.go index 6fa7515..a7f5f51 100644 --- a/client/sql/dataframe.go +++ b/client/sql/dataframe.go @@ -18,27 +18,41 @@ package sql import ( "bytes" + "context" "errors" "fmt" + "io" + "github.com/apache/arrow/go/v12/arrow" "github.com/apache/arrow/go/v12/arrow/array" "github.com/apache/arrow/go/v12/arrow/ipc" + "github.com/apache/spark-connect-go/v1/client/sparkerrors" proto "github.com/apache/spark-connect-go/v1/internal/generated" - "io" ) +// ResultCollector receives a stream of result rows +type ResultCollector interface { + // WriteRow receives a single row from the data frame + WriteRow(values []any) +} + // DataFrame is a wrapper for data frame, representing a distributed collection of data row. 
type DataFrame interface { - // Show prints out data frame data. - Show(numRows int, truncate bool) error + // WriteResult streams the data frames to a result collector + WriteResult(ctx context.Context, collector ResultCollector, numRows int, truncate bool) error + // Show uses WriteResult to write the data frames to the console output. + Show(ctx context.Context, numRows int, truncate bool) error // Schema returns the schema for the current data frame. - Schema() (*StructType, error) + Schema(ctx context.Context) (*StructType, error) // Collect returns the data rows of the current data frame. - Collect() ([]Row, error) - // Write returns a data frame writer, which could be used to save data frame to supported storage. + Collect(ctx context.Context) ([]Row, error) + // Writer returns a data frame writer, which could be used to save data frame to supported storage. + Writer() DataFrameWriter + // Write is an alias for Writer + // Deprecated: Use Writer Write() DataFrameWriter // CreateTempView creates or replaces a temporary view. - CreateTempView(viewName string, replace bool, global bool) error + CreateTempView(ctx context.Context, viewName string, replace bool, global bool) error // Repartition re-partitions a data frame. Repartition(numPartitions int, columns []string) (DataFrame, error) // RepartitionByRange re-partitions a data frame by range partition. @@ -52,11 +66,29 @@ type RangePartitionColumn struct { // dataFrameImpl is an implementation of DataFrame interface. type dataFrameImpl struct { - sparkSession *sparkSessionImpl - relation *proto.Relation // TODO change to proto.Plan? + sparkExecutor sparkExecutor + relation *proto.Relation // TODO change to proto.Plan? +} + +func newDataFrame(sparkExecutor sparkExecutor, relation *proto.Relation) DataFrame { + return &dataFrameImpl{ + sparkExecutor: sparkExecutor, + relation: relation, + } +} + +type consoleCollector struct { +} + +func (c consoleCollector) WriteRow(values []any) { + fmt.Println(values...) 
} -func (df *dataFrameImpl) Show(numRows int, truncate bool) error { +func (df *dataFrameImpl) Show(ctx context.Context, numRows int, truncate bool) error { + return df.WriteResult(ctx, &consoleCollector{}, numRows, truncate) +} + +func (df *dataFrameImpl) WriteResult(ctx context.Context, collector ResultCollector, numRows int, truncate bool) error { truncateValue := 0 if truncate { truncateValue = 20 @@ -81,45 +113,42 @@ func (df *dataFrameImpl) Show(numRows int, truncate bool) error { }, } - responseClient, err := df.sparkSession.executePlan(plan) + responseClient, err := df.sparkExecutor.executePlan(ctx, plan) if err != nil { - return fmt.Errorf("failed to show dataframe: %w", err) + return sparkerrors.WithType(fmt.Errorf("failed to show dataframe: %w", err), sparkerrors.ExecutionError) } for { response, err := responseClient.Recv() if err != nil { - return fmt.Errorf("failed to receive show response: %w", err) + return sparkerrors.WithType(fmt.Errorf("failed to receive show response: %w", err), sparkerrors.ReadError) } arrowBatch := response.GetArrowBatch() if arrowBatch == nil { continue } - err = showArrowBatch(arrowBatch) + err = showArrowBatch(arrowBatch, collector) if err != nil { return err } return nil } - - return fmt.Errorf("did not get arrow batch in response") } -func (df *dataFrameImpl) Schema() (*StructType, error) { - response, err := df.sparkSession.analyzePlan(df.createPlan()) +func (df *dataFrameImpl) Schema(ctx context.Context) (*StructType, error) { + response, err := df.sparkExecutor.analyzePlan(ctx, df.createPlan()) if err != nil { - return nil, fmt.Errorf("failed to analyze plan: %w", err) + return nil, sparkerrors.WithType(fmt.Errorf("failed to analyze plan: %w", err), sparkerrors.ExecutionError) } responseSchema := response.GetSchema().Schema - result := convertProtoDataTypeToStructType(responseSchema) - return result, nil + return convertProtoDataTypeToStructType(responseSchema) } -func (df *dataFrameImpl) Collect() ([]Row, error) { - responseClient, err := df.sparkSession.executePlan(df.createPlan()) +func (df *dataFrameImpl) Collect(ctx context.Context) ([]Row, error) { + responseClient, err := df.sparkExecutor.executePlan(ctx, df.createPlan()) if err != nil { - return nil, fmt.Errorf("failed to execute plan: %w", err) + return nil, sparkerrors.WithType(fmt.Errorf("failed to execute plan: %w", err), sparkerrors.ExecutionError) } var schema *StructType @@ -131,13 +160,16 @@ func (df *dataFrameImpl) Collect() ([]Row, error) { if errors.Is(err, io.EOF) { return allRows, nil } else { - return nil, fmt.Errorf("failed to receive plan execution response: %w", err) + return nil, sparkerrors.WithType(fmt.Errorf("failed to receive plan execution response: %w", err), sparkerrors.ReadError) } } dataType := response.GetSchema() if dataType != nil { - schema = convertProtoDataTypeToStructType(dataType) + schema, err = convertProtoDataTypeToStructType(dataType) + if err != nil { + return nil, err + } continue } @@ -156,19 +188,17 @@ func (df *dataFrameImpl) Collect() ([]Row, error) { } allRows = append(allRows, rowBatch...) 
} - - return allRows, nil } func (df *dataFrameImpl) Write() DataFrameWriter { - writer := dataFrameWriterImpl{ - sparkSession: df.sparkSession, - relation: df.relation, - } - return &writer + return df.Writer() } -func (df *dataFrameImpl) CreateTempView(viewName string, replace bool, global bool) error { +func (df *dataFrameImpl) Writer() DataFrameWriter { + return newDataFrameWriter(df.sparkExecutor, df.relation) +} + +func (df *dataFrameImpl) CreateTempView(ctx context.Context, viewName string, replace bool, global bool) error { plan := &proto.Plan{ OpType: &proto.Plan_Command{ Command: &proto.Command{ @@ -184,12 +214,12 @@ func (df *dataFrameImpl) CreateTempView(viewName string, replace bool, global bo }, } - responseClient, err := df.sparkSession.executePlan(plan) + responseClient, err := df.sparkExecutor.executePlan(ctx, plan) if err != nil { - return fmt.Errorf("failed to create temp view %s: %w", viewName, err) + return sparkerrors.WithType(fmt.Errorf("failed to create temp view %s: %w", viewName, err), sparkerrors.ExecutionError) } - return consumeExecutePlanClient(responseClient) + return responseClient.consumeAll() } func (df *dataFrameImpl) Repartition(numPartitions int, columns []string) (DataFrame, error) { @@ -272,17 +302,14 @@ func (df *dataFrameImpl) repartitionByExpressions(numPartitions int, partitionEx }, }, } - return &dataFrameImpl{ - sparkSession: df.sparkSession, - relation: newRelation, - }, nil + return newDataFrame(df.sparkExecutor, newRelation), nil } -func showArrowBatch(arrowBatch *proto.ExecutePlanResponse_ArrowBatch) error { - return showArrowBatchData(arrowBatch.Data) +func showArrowBatch(arrowBatch *proto.ExecutePlanResponse_ArrowBatch, collector ResultCollector) error { + return showArrowBatchData(arrowBatch.Data, collector) } -func showArrowBatchData(data []byte) error { +func showArrowBatchData(data []byte, collector ResultCollector) error { rows, err := readArrowBatchData(data, nil) if err != nil { return err @@ -290,9 +317,9 @@ func showArrowBatchData(data []byte) error { for _, row := range rows { values, err := row.Values() if err != nil { - return fmt.Errorf("failed to get values in the row: %w", err) + return sparkerrors.WithType(fmt.Errorf("failed to get values in the row: %w", err), sparkerrors.ReadError) } - fmt.Println(values...) 
+ collector.WriteRow(values) } return nil } @@ -301,7 +328,7 @@ func readArrowBatchData(data []byte, schema *StructType) ([]Row, error) { reader := bytes.NewReader(data) arrowReader, err := ipc.NewReader(reader) if err != nil { - return nil, fmt.Errorf("failed to create arrow reader: %w", err) + return nil, sparkerrors.WithType(fmt.Errorf("failed to create arrow reader: %w", err), sparkerrors.ReadError) } defer arrowReader.Release() @@ -313,7 +340,7 @@ func readArrowBatchData(data []byte, schema *StructType) ([]Row, error) { if errors.Is(err, io.EOF) { return rows, nil } else { - return nil, fmt.Errorf("failed to read arrow: %w", err) + return nil, sparkerrors.WithType(fmt.Errorf("failed to read arrow: %w", err), sparkerrors.ReadError) } } @@ -328,10 +355,7 @@ func readArrowBatchData(data []byte, schema *StructType) ([]Row, error) { } for _, v := range values { - row := &GenericRowWithSchema{ - schema: schema, - values: v, - } + row := NewRowWithSchema(v, schema) rows = append(rows, row) } @@ -445,14 +469,14 @@ func readArrowRecordColumn(record arrow.Record, columnIndex int, values [][]any) return nil } -func convertProtoDataTypeToStructType(input *proto.DataType) *StructType { +func convertProtoDataTypeToStructType(input *proto.DataType) (*StructType, error) { dataTypeStruct := input.GetStruct() if dataTypeStruct == nil { - panic("dataType.GetStruct() is nil") + return nil, sparkerrors.WithType(errors.New("dataType.GetStruct() is nil"), sparkerrors.InvalidInputError) } return &StructType{ Fields: convertProtoStructFields(dataTypeStruct.Fields), - } + }, nil } func convertProtoStructFields(input []*proto.DataType_StructField) []StructField { diff --git a/client/sql/dataframe_test.go b/client/sql/dataframe_test.go index ac5b162..98675fd 100644 --- a/client/sql/dataframe_test.go +++ b/client/sql/dataframe_test.go @@ -18,6 +18,9 @@ package sql import ( "bytes" + "context" + "testing" + "github.com/apache/arrow/go/v12/arrow" "github.com/apache/arrow/go/v12/arrow/array" "github.com/apache/arrow/go/v12/arrow/decimal128" @@ -28,7 +31,6 @@ import ( proto "github.com/apache/spark-connect-go/v1/internal/generated" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "testing" ) func TestShowArrowBatchData(t *testing.T) { @@ -56,8 +58,10 @@ func TestShowArrowBatchData(t *testing.T) { err := arrowWriter.Write(record) require.Nil(t, err) - err = showArrowBatchData(buf.Bytes()) + collector := &testCollector{} + err = showArrowBatchData(buf.Bytes(), collector) assert.Nil(t, err) + assert.Equal(t, []any{"str2"}, collector.row) } func TestReadArrowRecord(t *testing.T) { @@ -304,3 +308,73 @@ func TestConvertProtoDataTypeToDataType_UnsupportedType(t *testing.T) { } assert.Equal(t, "Unsupported", convertProtoDataTypeToDataType(unsupportedDataType).TypeName()) } + +type testCollector struct { + row []any +} + +func (t *testCollector) WriteRow(values []any) { + t.row = values +} + +func TestWriteResultStreamsArrowResultToCollector(t *testing.T) { + ctx := context.Background() + + arrowFields := []arrow.Field{ + { + Name: "show_string", + Type: &arrow.StringType{}, + }, + } + arrowSchema := arrow.NewSchema(arrowFields, nil) + var buf bytes.Buffer + arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) + defer arrowWriter.Close() + + alloc := memory.NewGoAllocator() + recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) + defer recordBuilder.Release() + + recordBuilder.Field(0).(*array.StringBuilder).Append("str1a\nstr1b") + 
recordBuilder.Field(0).(*array.StringBuilder).Append("str2") + + record := recordBuilder.NewRecord() + defer record.Release() + + err := arrowWriter.Write(record) + require.Nil(t, err) + + query := "select * from bla" + + session := &sparkSessionImpl{ + client: &connectServiceClient{ + executePlanClient: &executePlanClient{&protoClient{ + recvResponses: []*proto.ExecutePlanResponse{ + { + ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ + SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{}, + }, + }, + { + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + RowCount: 1, + Data: buf.Bytes(), + }, + }, + }, + }}, + }, + t: t, + }, + } + resp, err := session.Sql(ctx, query) + assert.NoError(t, err) + assert.NotNil(t, resp) + writer, err := resp.Repartition(1, []string{"1"}) + assert.NoError(t, err) + collector := &testCollector{} + err = writer.WriteResult(ctx, collector, 1, false) + assert.NoError(t, err) + assert.Equal(t, []any{"str2"}, collector.row) +} diff --git a/client/sql/dataframereader.go b/client/sql/dataframereader.go index 2f0e215..59f21ff 100644 --- a/client/sql/dataframereader.go +++ b/client/sql/dataframereader.go @@ -14,34 +14,40 @@ type DataFrameReader interface { // dataFrameReaderImpl is an implementation of DataFrameReader interface. type dataFrameReaderImpl struct { - sparkSession *sparkSessionImpl + sparkSession sparkExecutor formatSource string } +func newDataframeReader(session sparkExecutor) DataFrameReader { + return &dataFrameReaderImpl{ + sparkSession: session, + } +} + func (w *dataFrameReaderImpl) Format(source string) DataFrameReader { w.formatSource = source return w } func (w *dataFrameReaderImpl) Load(path string) (DataFrame, error) { - var format *string + var format string if w.formatSource != "" { - format = &w.formatSource + format = w.formatSource } - df := &dataFrameImpl{ - sparkSession: w.sparkSession, - relation: &proto.Relation{ - RelType: &proto.Relation_Read{ - Read: &proto.Read{ - ReadType: &proto.Read_DataSource_{ - DataSource: &proto.Read_DataSource{ - Format: format, - Paths: []string{path}, - }, + return newDataFrame(w.sparkSession, toRelation(path, format)), nil +} + +func toRelation(path string, format string) *proto.Relation { + return &proto.Relation{ + RelType: &proto.Relation_Read{ + Read: &proto.Read{ + ReadType: &proto.Read_DataSource_{ + DataSource: &proto.Read_DataSource{ + Format: &format, + Paths: []string{path}, }, }, }, }, } - return df, nil } diff --git a/client/sql/dataframereader_test.go b/client/sql/dataframereader_test.go new file mode 100644 index 0000000..c5e73c0 --- /dev/null +++ b/client/sql/dataframereader_test.go @@ -0,0 +1,26 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLoadCreatesADataFrame(t *testing.T) { + reader := newDataframeReader(nil) + source := "source" + path := "path" + reader.Format(source) + frame, err := reader.Load(path) + assert.NoError(t, err) + assert.NotNil(t, frame) +} + +func TestRelationContainsPathAndFormat(t *testing.T) { + formatSource := "source" + path := "path" + relation := toRelation(path, formatSource) + assert.NotNil(t, relation) + assert.Equal(t, &formatSource, relation.GetRead().GetDataSource().Format) + assert.Equal(t, path, relation.GetRead().GetDataSource().Paths[0]) +} diff --git a/client/sql/dataframewriter.go b/client/sql/dataframewriter.go index a398ebd..dee74e8 100644 --- a/client/sql/dataframewriter.go +++ b/client/sql/dataframewriter.go @@ -1,9 
+1,12 @@ package sql import ( + "context" "fmt" - proto "github.com/apache/spark-connect-go/v1/internal/generated" "strings" + + "github.com/apache/spark-connect-go/v1/client/sparkerrors" + proto "github.com/apache/spark-connect-go/v1/internal/generated" ) // DataFrameWriter supports writing data frame to storage. @@ -13,15 +16,27 @@ type DataFrameWriter interface { // Format specifies data format (data source type) for the underlying data, e.g. parquet. Format(source string) DataFrameWriter // Save writes data frame to the given path. - Save(path string) error + Save(ctx context.Context, path string) error +} + +type sparkExecutor interface { + executePlan(ctx context.Context, plan *proto.Plan) (*executePlanClient, error) + analyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) +} + +func newDataFrameWriter(sparkExecutor sparkExecutor, relation *proto.Relation) DataFrameWriter { + return &dataFrameWriterImpl{ + sparkExecutor: sparkExecutor, + relation: relation, + } } // dataFrameWriterImpl is an implementation of DataFrameWriter interface. type dataFrameWriterImpl struct { - sparkSession *sparkSessionImpl - relation *proto.Relation - saveMode string - formatSource string + sparkExecutor sparkExecutor + relation *proto.Relation + saveMode string + formatSource string } func (w *dataFrameWriterImpl) Mode(saveMode string) DataFrameWriter { @@ -34,7 +49,7 @@ func (w *dataFrameWriterImpl) Format(source string) DataFrameWriter { return w } -func (w *dataFrameWriterImpl) Save(path string) error { +func (w *dataFrameWriterImpl) Save(ctx context.Context, path string) error { saveMode, err := getSaveMode(w.saveMode) if err != nil { return err @@ -59,12 +74,12 @@ func (w *dataFrameWriterImpl) Save(path string) error { }, }, } - responseClient, err := w.sparkSession.executePlan(plan) + responseClient, err := w.sparkExecutor.executePlan(ctx, plan) if err != nil { return err } - return consumeExecutePlanClient(responseClient) + return responseClient.consumeAll() } func getSaveMode(mode string) (proto.WriteOperation_SaveMode, error) { @@ -79,6 +94,6 @@ func getSaveMode(mode string) (proto.WriteOperation_SaveMode, error) { } else if strings.EqualFold(mode, "Ignore") { return proto.WriteOperation_SAVE_MODE_IGNORE, nil } else { - return 0, fmt.Errorf("unsupported save mode: %s", mode) + return 0, sparkerrors.WithType(fmt.Errorf("unsupported save mode: %s", mode), sparkerrors.InvalidInputError) } } diff --git a/client/sql/dataframewriter_test.go b/client/sql/dataframewriter_test.go index d61266f..649e013 100644 --- a/client/sql/dataframewriter_test.go +++ b/client/sql/dataframewriter_test.go @@ -1,9 +1,12 @@ package sql import ( + "context" + "io" + "testing" + proto "github.com/apache/spark-connect-go/v1/internal/generated" "github.com/stretchr/testify/assert" - "testing" ) func TestGetSaveMode(t *testing.T) { @@ -31,3 +34,53 @@ func TestGetSaveMode(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, proto.WriteOperation_SAVE_MODE_UNSPECIFIED, mode) } + +func TestSaveExecutesWriteOperationUntilEOF(t *testing.T) { + relation := &proto.Relation{} + executor := &testExecutor{ + client: newExecutePlanClient(&protoClient{ + err: io.EOF, + }), + } + ctx := context.Background() + path := "path" + + writer := newDataFrameWriter(executor, relation) + writer.Format("format") + writer.Mode("append") + err := writer.Save(ctx, path) + assert.NoError(t, err) +} + +func TestSaveFailsIfAnotherErrorHappensWhenReadingStream(t *testing.T) { + relation := &proto.Relation{} + executor := 
&testExecutor{ + client: newExecutePlanClient(&protoClient{ + err: assert.AnError, + }), + } + ctx := context.Background() + path := "path" + + writer := newDataFrameWriter(executor, relation) + writer.Format("format") + writer.Mode("append") + err := writer.Save(ctx, path) + assert.Error(t, err) +} + +func TestSaveFailsIfAnotherErrorHappensWhenExecuting(t *testing.T) { + relation := &proto.Relation{} + executor := &testExecutor{ + client: newExecutePlanClient(&protoClient{}), + err: assert.AnError, + } + ctx := context.Background() + path := "path" + + writer := newDataFrameWriter(executor, relation) + writer.Format("format") + writer.Mode("append") + err := writer.Save(ctx, path) + assert.Error(t, err) +} diff --git a/client/sql/datatype.go b/client/sql/datatype.go index e201114..ecbadf3 100644 --- a/client/sql/datatype.go +++ b/client/sql/datatype.go @@ -17,7 +17,7 @@ package sql import ( - "reflect" + "fmt" "strings" ) @@ -125,16 +125,7 @@ func (t UnsupportedType) TypeName() string { } func getDataTypeName(dataType DataType) string { - t := reflect.TypeOf(dataType) - if t == nil { - return "(nil)" - } - var name string - if t.Kind() == reflect.Ptr { - name = t.Elem().Name() - } else { - name = t.Name() - } - name = strings.TrimSuffix(name, "Type") - return name + typeName := fmt.Sprintf("%T", dataType) + nonQualifiedTypeName := strings.Split(typeName, ".")[1] + return strings.TrimSuffix(nonQualifiedTypeName, "Type") } diff --git a/client/sql/executeplanclient.go b/client/sql/executeplanclient.go new file mode 100644 index 0000000..ece0a20 --- /dev/null +++ b/client/sql/executeplanclient.go @@ -0,0 +1,37 @@ +package sql + +import ( + "errors" + "fmt" + "io" + + "github.com/apache/spark-connect-go/v1/client/sparkerrors" + proto "github.com/apache/spark-connect-go/v1/internal/generated" +) + +type executePlanClient struct { + proto.SparkConnectService_ExecutePlanClient +} + +func newExecutePlanClient(responseClient proto.SparkConnectService_ExecutePlanClient) *executePlanClient { + return &executePlanClient{ + responseClient, + } +} + +// consumeAll reads through the returned GRPC stream from Spark Connect Driver. It will +// discard the returned data if there is no error. This is necessary for handling GRPC response for +// saving data frame, since such consuming will trigger Spark Connect Driver really saving data frame. +// If we do not consume the returned GRPC stream, Spark Connect Driver will not really save data frame. 
+func (c *executePlanClient) consumeAll() error { + for { + _, err := c.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } else { + return sparkerrors.WithType(fmt.Errorf("failed to receive plan execution response: %w", err), sparkerrors.ReadError) + } + } + } +} diff --git a/client/sql/mocks_test.go b/client/sql/mocks_test.go new file mode 100644 index 0000000..9d87d50 --- /dev/null +++ b/client/sql/mocks_test.go @@ -0,0 +1,112 @@ +package sql + +import ( + "context" + "testing" + + proto "github.com/apache/spark-connect-go/v1/internal/generated" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +type testExecutor struct { + client *executePlanClient + response *proto.AnalyzePlanResponse + err error +} + +func (t *testExecutor) executePlan(ctx context.Context, plan *proto.Plan) (*executePlanClient, error) { + if t.err != nil { + return nil, t.err + } + return t.client, nil +} + +func (t *testExecutor) analyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) { + return t.response, nil +} + +type protoClient struct { + recvResponse *proto.ExecutePlanResponse + recvResponses []*proto.ExecutePlanResponse + + err error +} + +func (p *protoClient) Recv() (*proto.ExecutePlanResponse, error) { + if len(p.recvResponses) != 0 { + p.recvResponse = p.recvResponses[0] + p.recvResponses = p.recvResponses[1:] + } + return p.recvResponse, p.err +} + +func (p *protoClient) Header() (metadata.MD, error) { + return nil, p.err +} + +func (p *protoClient) Trailer() metadata.MD { + return nil +} + +func (p *protoClient) CloseSend() error { + return p.err +} + +func (p *protoClient) Context() context.Context { + return nil +} + +func (p *protoClient) SendMsg(m interface{}) error { + return p.err +} + +func (p *protoClient) RecvMsg(m interface{}) error { + return p.err +} + +type connectServiceClient struct { + t *testing.T + + analysePlanResponse *proto.AnalyzePlanResponse + executePlanClient proto.SparkConnectService_ExecutePlanClient + expectedExecutePlanRequest *proto.ExecutePlanRequest + + err error +} + +func (c *connectServiceClient) ExecutePlan(ctx context.Context, in *proto.ExecutePlanRequest, opts ...grpc.CallOption) (proto.SparkConnectService_ExecutePlanClient, error) { + if c.expectedExecutePlanRequest != nil { + assert.Equal(c.t, c.expectedExecutePlanRequest, in) + } + return c.executePlanClient, c.err +} + +func (c *connectServiceClient) AnalyzePlan(ctx context.Context, in *proto.AnalyzePlanRequest, opts ...grpc.CallOption) (*proto.AnalyzePlanResponse, error) { + return c.analysePlanResponse, c.err +} + +func (c *connectServiceClient) Config(ctx context.Context, in *proto.ConfigRequest, opts ...grpc.CallOption) (*proto.ConfigResponse, error) { + return nil, c.err +} + +func (c *connectServiceClient) AddArtifacts(ctx context.Context, opts ...grpc.CallOption) (proto.SparkConnectService_AddArtifactsClient, error) { + return nil, c.err +} + +func (c *connectServiceClient) ArtifactStatus(ctx context.Context, in *proto.ArtifactStatusesRequest, opts ...grpc.CallOption) (*proto.ArtifactStatusesResponse, error) { + return nil, c.err +} + +func (c *connectServiceClient) Interrupt(ctx context.Context, in *proto.InterruptRequest, opts ...grpc.CallOption) (*proto.InterruptResponse, error) { + return nil, c.err +} + +func (c *connectServiceClient) ReattachExecute(ctx context.Context, in *proto.ReattachExecuteRequest, opts ...grpc.CallOption) (proto.SparkConnectService_ReattachExecuteClient, error) { + 
return nil, c.err +} + +func (c *connectServiceClient) ReleaseExecute(ctx context.Context, in *proto.ReleaseExecuteRequest, opts ...grpc.CallOption) (*proto.ReleaseExecuteResponse, error) { + return nil, c.err +} diff --git a/client/sql/plan_test.go b/client/sql/plan_test.go new file mode 100644 index 0000000..c733862 --- /dev/null +++ b/client/sql/plan_test.go @@ -0,0 +1,13 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewPlanIdGivesNewIDs(t *testing.T) { + id1 := newPlanId() + id2 := newPlanId() + assert.NotEqual(t, id1, id2) +} diff --git a/client/sql/row.go b/client/sql/row.go index 3bee2ac..bea2ab7 100644 --- a/client/sql/row.go +++ b/client/sql/row.go @@ -16,20 +16,31 @@ package sql +// Row represents a row in a DataFrame. type Row interface { + // Schema returns the schema of the row. Schema() (*StructType, error) + // Values returns the values of the row. Values() ([]any, error) } -type GenericRowWithSchema struct { +// genericRowWithSchema represents a row in a DataFrame with schema. +type genericRowWithSchema struct { values []any schema *StructType } -func (r *GenericRowWithSchema) Schema() (*StructType, error) { +func NewRowWithSchema(values []any, schema *StructType) Row { + return &genericRowWithSchema{ + values: values, + schema: schema, + } +} + +func (r *genericRowWithSchema) Schema() (*StructType, error) { return r.schema, nil } -func (r *GenericRowWithSchema) Values() ([]any, error) { +func (r *genericRowWithSchema) Values() ([]any, error) { return r.values, nil } diff --git a/client/sql/row_test.go b/client/sql/row_test.go new file mode 100644 index 0000000..7ae4f97 --- /dev/null +++ b/client/sql/row_test.go @@ -0,0 +1,25 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSchema(t *testing.T) { + values := []any{1} + schema := &StructType{} + row := NewRowWithSchema(values, schema) + schema2, err := row.Schema() + assert.NoError(t, err) + assert.Equal(t, schema, schema2) +} + +func TestValues(t *testing.T) { + values := []any{1} + schema := &StructType{} + row := NewRowWithSchema(values, schema) + values2, err := row.Values() + assert.NoError(t, err) + assert.Equal(t, values, values2) +} diff --git a/client/sql/sparksession.go b/client/sql/sparksession.go index a5ac466..e1ecef8 100644 --- a/client/sql/sparksession.go +++ b/client/sql/sparksession.go @@ -18,48 +18,46 @@ package sql import ( "context" - "errors" "fmt" "github.com/apache/spark-connect-go/v1/client/channel" + "github.com/apache/spark-connect-go/v1/client/sparkerrors" proto "github.com/apache/spark-connect-go/v1/internal/generated" "github.com/google/uuid" "google.golang.org/grpc/metadata" - "io" ) -var SparkSession sparkSessionBuilderEntrypoint - -type sparkSession interface { +type SparkSession interface { Read() DataFrameReader - Sql(query string) (DataFrame, error) + Sql(ctx context.Context, query string) (DataFrame, error) Stop() error } -type sparkSessionBuilderEntrypoint struct { - Builder SparkSessionBuilder +// NewSessionBuilder creates a new session builder for starting a new spark session +func NewSessionBuilder() *SparkSessionBuilder { + return &SparkSessionBuilder{} } type SparkSessionBuilder struct { connectionString string } -func (s SparkSessionBuilder) Remote(connectionString string) SparkSessionBuilder { - copy := s - copy.connectionString = connectionString - return copy +// Remote sets the connection string for remote connection +func (s *SparkSessionBuilder) Remote(connectionString string) 
*SparkSessionBuilder { + s.connectionString = connectionString + return s } -func (s SparkSessionBuilder) Build() (sparkSession, error) { +func (s *SparkSessionBuilder) Build(ctx context.Context) (SparkSession, error) { cb, err := channel.NewBuilder(s.connectionString) if err != nil { - return nil, fmt.Errorf("failed to connect to remote %s: %w", s.connectionString, err) + return nil, sparkerrors.WithType(fmt.Errorf("failed to connect to remote %s: %w", s.connectionString, err), sparkerrors.ConnectionError) } - conn, err := cb.Build() + conn, err := cb.Build(ctx) if err != nil { - return nil, fmt.Errorf("failed to connect to remote %s: %w", s.connectionString, err) + return nil, sparkerrors.WithType(fmt.Errorf("failed to connect to remote %s: %w", s.connectionString, err), sparkerrors.ConnectionError) } // Add metadata to the request. @@ -83,12 +81,10 @@ type sparkSessionImpl struct { } func (s *sparkSessionImpl) Read() DataFrameReader { - return &dataFrameReaderImpl{ - sparkSession: s, - } + return newDataframeReader(s) } -func (s *sparkSessionImpl) Sql(query string) (DataFrame, error) { +func (s *sparkSessionImpl) Sql(ctx context.Context, query string) (DataFrame, error) { plan := &proto.Plan{ OpType: &proto.Plan_Command{ Command: &proto.Command{ @@ -100,32 +96,28 @@ func (s *sparkSessionImpl) Sql(query string) (DataFrame, error) { }, }, } - responseClient, err := s.executePlan(plan) + responseClient, err := s.executePlan(ctx, plan) if err != nil { - return nil, fmt.Errorf("failed to execute sql: %s: %w", query, err) + return nil, sparkerrors.WithType(fmt.Errorf("failed to execute sql: %s: %w", query, err), sparkerrors.ExecutionError) } for { response, err := responseClient.Recv() if err != nil { - return nil, fmt.Errorf("failed to receive ExecutePlan response: %w", err) + return nil, sparkerrors.WithType(fmt.Errorf("failed to receive ExecutePlan response: %w", err), sparkerrors.ReadError) } sqlCommandResult := response.GetSqlCommandResult() if sqlCommandResult == nil { continue } - return &dataFrameImpl{ - sparkSession: s, - relation: sqlCommandResult.GetRelation(), - }, nil + return newDataFrame(s, sqlCommandResult.GetRelation()), nil } - return nil, fmt.Errorf("failed to get SqlCommandResult in ExecutePlan response") } func (s *sparkSessionImpl) Stop() error { return nil } -func (s *sparkSessionImpl) executePlan(plan *proto.Plan) (proto.SparkConnectService_ExecutePlanClient, error) { +func (s *sparkSessionImpl) executePlan(ctx context.Context, plan *proto.Plan) (*executePlanClient, error) { request := proto.ExecutePlanRequest{ SessionId: s.sessionId, Plan: plan, @@ -134,15 +126,15 @@ func (s *sparkSessionImpl) executePlan(plan *proto.Plan) (proto.SparkConnectServ }, } // Append the other items to the request. 
- ctx := metadata.NewOutgoingContext(context.Background(), s.metadata) - executePlanClient, err := s.client.ExecutePlan(ctx, &request) + ctx = metadata.NewOutgoingContext(ctx, s.metadata) + client, err := s.client.ExecutePlan(ctx, &request) if err != nil { - return nil, fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err) + return nil, sparkerrors.WithType(fmt.Errorf("failed to call ExecutePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError) } - return executePlanClient, nil + return newExecutePlanClient(client), nil } -func (s *sparkSessionImpl) analyzePlan(plan *proto.Plan) (*proto.AnalyzePlanResponse, error) { +func (s *sparkSessionImpl) analyzePlan(ctx context.Context, plan *proto.Plan) (*proto.AnalyzePlanResponse, error) { request := proto.AnalyzePlanRequest{ SessionId: s.sessionId, Analyze: &proto.AnalyzePlanRequest_Schema_{ @@ -155,29 +147,11 @@ func (s *sparkSessionImpl) analyzePlan(plan *proto.Plan) (*proto.AnalyzePlanResp }, } // Append the other items to the request. - ctx := metadata.NewOutgoingContext(context.Background(), s.metadata) + ctx = metadata.NewOutgoingContext(ctx, s.metadata) response, err := s.client.AnalyzePlan(ctx, &request) if err != nil { - return nil, fmt.Errorf("failed to call AnalyzePlan in session %s: %w", s.sessionId, err) + return nil, sparkerrors.WithType(fmt.Errorf("failed to call AnalyzePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError) } return response, nil } - -// consumeExecutePlanClient reads through the returned GRPC stream from Spark Connect Driver. It will -// discard the returned data if there is no error. This is necessary for handling GRPC response for -// saving data frame, since such consuming will trigger Spark Connect Driver really saving data frame. -// If we do not consume the returned GRPC stream, Spark Connect Driver will not really save data frame. 
-func consumeExecutePlanClient(responseClient proto.SparkConnectService_ExecutePlanClient) error { - for { - _, err := responseClient.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - return nil - } else { - return fmt.Errorf("failed to receive plan execution response: %w", err) - } - } - } - return nil -} diff --git a/client/sql/sparksession_test.go b/client/sql/sparksession_test.go new file mode 100644 index 0000000..0a3b105 --- /dev/null +++ b/client/sql/sparksession_test.go @@ -0,0 +1,111 @@ +package sql + +import ( + "context" + "testing" + + "github.com/apache/spark-connect-go/v1/client/sparkerrors" + proto "github.com/apache/spark-connect-go/v1/internal/generated" + "github.com/stretchr/testify/assert" +) + +func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) { + ctx := context.Background() + reponse := &proto.AnalyzePlanResponse{} + session := &sparkSessionImpl{ + client: &connectServiceClient{ + analysePlanResponse: reponse, + }, + } + resp, err := session.analyzePlan(ctx, &proto.Plan{}) + assert.NoError(t, err) + assert.NotNil(t, resp) +} + +func TestAnalyzePlanFailsIfClientFails(t *testing.T) { + ctx := context.Background() + session := &sparkSessionImpl{ + client: &connectServiceClient{ + err: assert.AnError, + }, + } + resp, err := session.analyzePlan(ctx, &proto.Plan{}) + assert.Nil(t, resp) + assert.Error(t, err) +} + +func TestExecutePlanCallsExecutePlanOnClient(t *testing.T) { + ctx := context.Background() + + plan := &proto.Plan{} + request := &proto.ExecutePlanRequest{ + Plan: plan, + UserContext: &proto.UserContext{ + UserId: "na", + }, + } + session := &sparkSessionImpl{ + client: &connectServiceClient{ + executePlanClient: &executePlanClient{}, + expectedExecutePlanRequest: request, + t: t, + }, + } + resp, err := session.executePlan(ctx, plan) + assert.NoError(t, err) + assert.NotNil(t, resp) +} + +func TestSQLCallsExecutePlanWithSQLOnClient(t *testing.T) { + ctx := context.Background() + + query := "select * from bla" + plan := &proto.Plan{ + OpType: &proto.Plan_Command{ + Command: &proto.Command{ + CommandType: &proto.Command_SqlCommand{ + SqlCommand: &proto.SqlCommand{ + Sql: query, + }, + }, + }, + }, + } + request := &proto.ExecutePlanRequest{ + Plan: plan, + UserContext: &proto.UserContext{ + UserId: "na", + }, + } + session := &sparkSessionImpl{ + client: &connectServiceClient{ + executePlanClient: &executePlanClient{&protoClient{ + recvResponse: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ + SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{}, + }, + }, + }}, + expectedExecutePlanRequest: request, + t: t, + }, + } + resp, err := session.Sql(ctx, query) + assert.NoError(t, err) + assert.NotNil(t, resp) +} + +func TestNewSessionBuilderCreatesASession(t *testing.T) { + ctx := context.Background() + spark, err := NewSessionBuilder().Remote("sc:connection").Build(ctx) + assert.NoError(t, err) + assert.NotNil(t, spark) +} + +func TestNewSessionBuilderFailsIfConnectionStringIsInvalid(t *testing.T) { + ctx := context.Background() + spark, err := NewSessionBuilder().Remote("invalid").Build(ctx) + assert.Error(t, err) + assert.ErrorIs(t, err, sparkerrors.InvalidInputError) + assert.Nil(t, spark) +} diff --git a/client/sql/structtype.go b/client/sql/structtype.go index 2a59c30..fd75236 100644 --- a/client/sql/structtype.go +++ b/client/sql/structtype.go @@ -16,12 +16,14 @@ package sql +// StructField represents a field in a StructType. 
type StructField struct { Name string DataType DataType Nullable bool // default should be true } +// StructType represents a struct type. type StructType struct { TypeName string Fields []StructField diff --git a/cmd/spark-connect-example-raw-grpc-client/main.go b/cmd/spark-connect-example-raw-grpc-client/main.go index f48500b..25e48d9 100644 --- a/cmd/spark-connect-example-raw-grpc-client/main.go +++ b/cmd/spark-connect-example-raw-grpc-client/main.go @@ -19,12 +19,13 @@ package main import ( "context" "flag" + "log" + "time" + proto "github.com/apache/spark-connect-go/v1/internal/generated" "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - "log" - "time" ) var ( @@ -32,13 +33,14 @@ var ( ) func main() { + ctx := context.Background() opts := []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), } - conn, err := grpc.Dial(*remote, opts...) + conn, err := grpc.DialContext(ctx, *remote, opts...) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } defer conn.Close() @@ -57,7 +59,7 @@ func main() { } configResponse, err := client.Config(ctx, &configRequest) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } log.Printf("configResponse: %v", configResponse) diff --git a/cmd/spark-connect-example-spark-session/main.go b/cmd/spark-connect-example-spark-session/main.go index aeb95b8..f19a199 100644 --- a/cmd/spark-connect-example-spark-session/main.go +++ b/cmd/spark-connect-example-spark-session/main.go @@ -17,6 +17,7 @@ package main import ( + "context" "flag" "log" @@ -30,40 +31,41 @@ var ( func main() { flag.Parse() - spark, err := sql.SparkSession.Builder.Remote(*remote).Build() + ctx := context.Background() + spark, err := sql.NewSessionBuilder().Remote(*remote).Build(ctx) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } defer spark.Stop() - df, err := spark.Sql("select 'apple' as word, 123 as count union all select 'orange' as word, 456 as count") + df, err := spark.Sql(ctx, "select 'apple' as word, 123 as count union all select 'orange' as word, 456 as count") if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } log.Printf("DataFrame from sql: select 'apple' as word, 123 as count union all select 'orange' as word, 456 as count") - err = df.Show(100, false) + err = df.Show(ctx, 100, false) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } - schema, err := df.Schema() + schema, err := df.Schema(ctx) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } for _, f := range schema.Fields { log.Printf("Field in dataframe schema: %s - %s", f.Name, f.DataType.TypeName()) } - rows, err := df.Collect() + rows, err := df.Collect(ctx) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } schema, err = rows[0].Schema() if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } for _, f := range schema.Fields { @@ -74,72 +76,78 @@ func main() { log.Printf("Row: %v", row) } - err = df.Write().Mode("overwrite"). + err = df.Writer().Mode("overwrite"). Format("parquet"). - Save("file:///tmp/spark-connect-write-example-output.parquet") + Save(ctx, "file:///tmp/spark-connect-write-example-output.parquet") if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } df, err = spark.Read().Format("parquet"). 
Load("file:///tmp/spark-connect-write-example-output.parquet") if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } log.Printf("DataFrame from reading parquet") - df.Show(100, false) + err = df.Show(ctx, 100, false) + if err != nil { + log.Fatalf("Failed: %s", err) + } - err = df.CreateTempView("view1", true, false) + err = df.CreateTempView(ctx, "view1", true, false) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } - df, err = spark.Sql("select count, word from view1 order by count") + df, err = spark.Sql(ctx, "select count, word from view1 order by count") if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } log.Printf("DataFrame from sql: select count, word from view1 order by count") - df.Show(100, false) + err = df.Show(ctx, 100, false) + if err != nil { + log.Fatalf("Failed: %s", err) + } log.Printf("Repartition with one partition") df, err = df.Repartition(1, nil) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } - err = df.Write().Mode("overwrite"). + err = df.Writer().Mode("overwrite"). Format("parquet"). - Save("file:///tmp/spark-connect-write-example-output-one-partition.parquet") + Save(ctx, "file:///tmp/spark-connect-write-example-output-one-partition.parquet") if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } log.Printf("Repartition with two partitions") df, err = df.Repartition(2, nil) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } - err = df.Write().Mode("overwrite"). + err = df.Writer().Mode("overwrite"). Format("parquet"). - Save("file:///tmp/spark-connect-write-example-output-two-partition.parquet") + Save(ctx, "file:///tmp/spark-connect-write-example-output-two-partition.parquet") if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } log.Printf("Repartition with columns") df, err = df.Repartition(0, []string{"word", "count"}) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } - err = df.Write().Mode("overwrite"). + err = df.Writer().Mode("overwrite"). Format("parquet"). - Save("file:///tmp/spark-connect-write-example-output-repartition-with-column.parquet") + Save(ctx, "file:///tmp/spark-connect-write-example-output-repartition-with-column.parquet") if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } log.Printf("Repartition by range with columns") @@ -150,13 +158,13 @@ func main() { }, }) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } - err = df.Write().Mode("overwrite"). + err = df.Writer().Mode("overwrite"). Format("parquet"). 
- Save("file:///tmp/spark-connect-write-example-output-repartition-by-range-with-column.parquet") + Save(ctx, "file:///tmp/spark-connect-write-example-output-repartition-by-range-with-column.parquet") if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } } diff --git a/go.mod b/go.mod index 72f627c..ffdb132 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ module github.com/apache/spark-connect-go/v1 -go 1.19 +go 1.21 require ( github.com/apache/arrow/go/v12 v12.0.0 diff --git a/quick-start.md b/quick-start.md index 8e1b1e1..2cf83a4 100644 --- a/quick-start.md +++ b/quick-start.md @@ -18,6 +18,7 @@ Create `main.go` file with following code: package main import ( + "context" "flag" "log" @@ -31,56 +32,78 @@ var ( func main() { flag.Parse() - spark, err := sql.SparkSession.Builder.Remote(*remote).Build() + ctx := context.Background() + spark, err := sql.NewSessionBuilder().Remote(*remote).Build(ctx) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } defer spark.Stop() - df, err := spark.Sql("select 'apple' as word, 123 as count union all select 'orange' as word, 456 as count") + df, err := spark.Sql(ctx, "select 'apple' as word, 123 as count union all select 'orange' as word, 456 as count") if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } log.Printf("DataFrame from sql: select 'apple' as word, 123 as count union all select 'orange' as word, 456 as count") - err = df.Show(100, false) + err = df.Show(ctx, 100, false) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } - rows, err := df.Collect() + schema, err := df.Schema(ctx) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) + } + + for _, f := range schema.Fields { + log.Printf("Field in dataframe schema: %s - %s", f.Name, f.DataType.TypeName()) + } + + rows, err := df.Collect(ctx) + if err != nil { + log.Fatalf("Failed: %s", err) + } + + schema, err = rows[0].Schema() + if err != nil { + log.Fatalf("Failed: %s", err) + } + + for _, f := range schema.Fields { + log.Printf("Field in row: %s - %s", f.Name, f.DataType.TypeName()) } for _, row := range rows { log.Printf("Row: %v", row) } - err = df.Write().Mode("overwrite"). + err = df.Writer().Mode("overwrite"). Format("parquet"). - Save("file:///tmp/spark-connect-write-example-output.parquet") + Save(ctx, "file:///tmp/spark-connect-write-example-output.parquet") if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } df, err = spark.Read().Format("parquet"). 
Load("file:///tmp/spark-connect-write-example-output.parquet") if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } log.Printf("DataFrame from reading parquet") - df.Show(100, false) + err = df.Show(ctx, 100, false) + if err != nil { + log.Fatalf("Failed: %s", err) + } - err = df.CreateTempView("view1", true, false) + err = df.CreateTempView(ctx, "view1", true, false) if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } - df, err = spark.Sql("select count, word from view1 order by count") + df, err = spark.Sql(ctx, "select count, word from view1 order by count") if err != nil { - log.Fatalf("Failed: %s", err.Error()) + log.Fatalf("Failed: %s", err) } log.Printf("DataFrame from sql: select count, word from view1 order by count") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org