This is an automated email from the ASF dual-hosted git repository.

mgrund pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 957a4b3  [SPARK-48982] Properly extract Spark Errors from GRPC Request
957a4b3 is described below

commit 957a4b3476bfe748394a3e0a106311605e216288
Author: Martin Grund <[email protected]>
AuthorDate: Wed Jul 24 15:30:17 2024 +0200

    [SPARK-48982] Properly extract Spark Errors from GRPC Request
    
    ### What changes were proposed in this pull request?
    Properly extract Spark exception information from the gRPC status and its
    associated `ErrorInfo` details. The extracted information is exposed through a
    new `SparkError` type that gives access to the error metadata (SQL state, error
    class, error ID, and message parameters).
    
    ### Why are the changes needed?
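    Compatibility: previously, RPC errors were only wrapped in formatted error
    strings, so callers had no structured way to access the Spark error class,
    SQL state, or message parameters returned by the server.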
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added new unit tests in `spark/sparkerrors/errors_test.go`.
    
    Closes #36 from grundprinzip/SPARK-48982.
    
    Authored-by: Martin Grund <[email protected]>
    Signed-off-by: Martin Grund <[email protected]>
---
 cmd/spark-connect-example-spark-session/main.go |  46 ++++----
 spark/client/client.go                          |  14 ++-
 spark/sparkerrors/errors.go                     |  73 ++++++++++++
 spark/sparkerrors/errors_test.go                | 149 ++++++++++++++++++++++++
 4 files changed, 255 insertions(+), 27 deletions(-)

diff --git a/cmd/spark-connect-example-spark-session/main.go 
b/cmd/spark-connect-example-spark-session/main.go
index 5f63bcc..71f6f07 100644
--- a/cmd/spark-connect-example-spark-session/main.go
+++ b/cmd/spark-connect-example-spark-session/main.go
@@ -39,29 +39,29 @@ func main() {
        }
        defer utils.WarnOnError(spark.Stop, func(err error) {})
 
-       //df, err := spark.Sql(ctx, "select * from range(100)")
-       //if err != nil {
-       //      log.Fatalf("Failed: %s", err)
-       //}
-       //
-       //df, _ = df.FilterByString("id < 10")
-       //err = df.Show(ctx, 100, false)
-       //if err != nil {
-       //      log.Fatalf("Failed: %s", err)
-       //}
-       //
-       //df, err = spark.Sql(ctx, "select * from range(100)")
-       //if err != nil {
-       //      log.Fatalf("Failed: %s", err)
-       //}
-       //
-       //df, _ = df.Filter(functions.Col("id").Lt(functions.Expr("10")))
-       //err = df.Show(ctx, 100, false)
-       //if err != nil {
-       //      log.Fatalf("Failed: %s", err)
-       //}
-
-       df, _ := spark.Sql(ctx, "select * from range(100)")
+       df, err := spark.Sql(ctx, "select id2 from range(100)")
+       if err != nil {
+               log.Fatalf("Failed: %s", err)
+       }
+
+       df, _ = df.FilterByString("id < 10")
+       err = df.Show(ctx, 100, false)
+       if err != nil {
+               log.Fatalf("Failed: %s", err)
+       }
+
+       df, err = spark.Sql(ctx, "select * from range(100)")
+       if err != nil {
+               log.Fatalf("Failed: %s", err)
+       }
+
+       df, _ = df.Filter(functions.Col("id").Lt(functions.Expr("10")))
+       err = df.Show(ctx, 100, false)
+       if err != nil {
+               log.Fatalf("Failed: %s", err)
+       }
+
+       df, _ = spark.Sql(ctx, "select * from range(100)")
        df, err = df.Filter(functions.Col("id").Lt(functions.Lit(20)))
        if err != nil {
                log.Fatalf("Failed: %s", err)
diff --git a/spark/client/client.go b/spark/client/client.go
index ed65f44..a9b163e 100644
--- a/spark/client/client.go
+++ b/spark/client/client.go
@@ -111,8 +111,8 @@ func (s *SparkExecutorImpl) AnalyzePlan(ctx 
context.Context, plan *proto.Plan) (
        ctx = metadata.NewOutgoingContext(ctx, s.metadata)
 
        response, err := s.client.AnalyzePlan(ctx, &request)
-       if err != nil {
-               return nil, sparkerrors.WithType(fmt.Errorf("failed to call 
AnalyzePlan in session %s: %w", s.sessionId, err), sparkerrors.ExecutionError)
+       if se := sparkerrors.FromRPCError(err); se != nil {
+               return nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)
        }
        return response, nil
 }
@@ -126,6 +126,8 @@ func NewSparkExecutor(conn *grpc.ClientConn, md 
metadata.MD, sessionId string) S
        }
 }
 
+// NewSparkExecutorFromClient creates a new SparkExecutor from an existing 
client and is mostly
+// used in testing.
 func NewSparkExecutorFromClient(client proto.SparkConnectServiceClient, md 
metadata.MD, sessionId string) SparkExecutor {
        return &SparkExecutorImpl{
                client:    client,
@@ -156,11 +158,15 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, 
arrow.Table, error) {
 
        for {
                resp, err := c.responseStream.Recv()
+               // EOF is received when the last message has been processed and 
the stream
+               // finished normally.
                if err == io.EOF {
                        break
                }
-               if err != nil {
-                       return nil, nil, 
sparkerrors.WithType(fmt.Errorf("failed to receive plan execution response: 
%w", err), sparkerrors.ReadError)
+
+               // If the error was not EOF, there might be another error.
+               if se := sparkerrors.FromRPCError(err); se != nil {
+                       return nil, nil, sparkerrors.WithType(se, 
sparkerrors.ExecutionError)
                }
 
                // Process the message
diff --git a/spark/sparkerrors/errors.go b/spark/sparkerrors/errors.go
index 030db86..eecd34d 100644
--- a/spark/sparkerrors/errors.go
+++ b/spark/sparkerrors/errors.go
@@ -17,8 +17,13 @@
 package sparkerrors
 
 import (
+       "encoding/json"
        "errors"
        "fmt"
+
+       "google.golang.org/genproto/googleapis/rpc/errdetails"
+       "google.golang.org/grpc/codes"
+       "google.golang.org/grpc/status"
 )
 
 type wrappedError struct {
@@ -65,3 +70,71 @@ type InvalidServerSideSessionError struct {
 func (e InvalidServerSideSessionError) Error() string {
        return fmt.Sprintf("Received invalid session id %s, expected %s", 
e.ReceivedSessionId, e.OwnSessionId)
 }
+
+// SparkError represents an error that is returned from Spark itself. It 
captures details of the
+// error that allows better understanding about the error. This allows us to 
check if the error
+// can be retried or not.
+type SparkError struct {
+       // SqlState is the SQL state of the error.
+       SqlState string
+       // ErrorClass is the class of the error.
+       ErrorClass string
+       // If set is typically the classname throwing the error on the Spark 
side.
+       Reason string
+       // Message is the human-readable message of the error.
+       Message string
+       // Code is the gRPC status code of the error.
+       Code codes.Code
+       // ErrorId is the unique id of the error. It can be used to fetch more 
details about
+       // the error using an additional RPC from the server.
+       ErrorId string
+       // Parameters are the parameters that are used to format the error 
message.
+       Parameters map[string]string
+       status     *status.Status
+}
+
+func (e SparkError) Error() string {
+       if e.Code == codes.Internal && e.SqlState != "" {
+               return fmt.Sprintf("[%s] %s. SQLSTATE: %s", e.ErrorClass, 
e.Message, e.SqlState)
+       } else {
+               return fmt.Sprintf("[%s] %s", e.Code.String(), e.Message)
+       }
+}
+
+// FromRPCError converts a gRPC error to a SparkError. If the error is not a 
gRPC error, it will
+// create a plain "UNKNOWN" GRPC status type. If no error was observed returns 
nil.
+func FromRPCError(e error) *SparkError {
+       status := status.Convert(e)
+       // If there was no error, simply pass through.
+       if status == nil {
+               return nil
+       }
+       result := &SparkError{
+               Message: status.Message(),
+               Code:    status.Code(),
+               status:  status,
+       }
+
+       // Now lets, check if we can extract the error info from the details.
+       for _, d := range status.Details() {
+               switch info := d.(type) {
+               case *errdetails.ErrorInfo:
+                       // Parse the parameters from the error details, but 
only parse them if
+                       // they're present.
+                       var params map[string]string
+                       if v, ok := info.GetMetadata()["messageParameters"]; ok 
{
+                               err := json.Unmarshal([]byte(v), &params)
+                               if err == nil {
+                                       // The message parameters is properly 
formatted JSON, if for some reason
+                                       // this is not the case, errors are 
ignored.
+                                       result.Parameters = params
+                               }
+                       }
+                       result.SqlState = info.GetMetadata()["sqlState"]
+                       result.ErrorClass = info.GetMetadata()["errorClass"]
+                       result.ErrorId = info.GetMetadata()["errorId"]
+                       result.Reason = info.Reason
+               }
+       }
+       return result
+}
diff --git a/spark/sparkerrors/errors_test.go b/spark/sparkerrors/errors_test.go
index d12a1fb..184ec97 100644
--- a/spark/sparkerrors/errors_test.go
+++ b/spark/sparkerrors/errors_test.go
@@ -18,6 +18,10 @@ package sparkerrors
 import (
        "testing"
 
+       "google.golang.org/genproto/googleapis/rpc/errdetails"
+       "google.golang.org/grpc/codes"
+       "google.golang.org/grpc/status"
+
        "github.com/stretchr/testify/assert"
 )
 
@@ -30,3 +34,148 @@ func TestErrorStringContainsErrorType(t *testing.T) {
        err := WithType(assert.AnError, ConnectionError)
        assert.Contains(t, err.Error(), ConnectionError.Error())
 }
+
+func TestGRPCErrorConversion(t *testing.T) {
+       err := status.Error(codes.Internal, "invalid argument")
+       se := FromRPCError(err)
+       assert.Equal(t, se.Code, codes.Internal)
+       assert.Equal(t, se.Message, "invalid argument")
+}
+
+func TestNonGRPCErrorsAreConvertedAsWell(t *testing.T) {
+       err := assert.AnError
+       se := FromRPCError(err)
+       assert.Equal(t, se.Code, codes.Unknown)
+       assert.Equal(t, se.Message, assert.AnError.Error())
+}
+
+func TestErrorDetailsExtractionFromGRPCStatus(t *testing.T) {
+       status := status.New(codes.Internal, "AnalysisException")
+       status, _ = status.WithDetails(&errdetails.ErrorInfo{
+               Reason:   "AnalysisException",
+               Domain:   "spark.sql",
+               Metadata: map[string]string{},
+       })
+
+       err := status.Err()
+       se := FromRPCError(err)
+       assert.Equal(t, codes.Internal, se.Code)
+       assert.Equal(t, "AnalysisException", se.Message)
+       assert.Equal(t, "AnalysisException", se.Reason)
+}
+
+func TestErrorDetailsWithSqlStateAndClass(t *testing.T) {
+       status := status.New(codes.Internal, "AnalysisException")
+       status, _ = status.WithDetails(&errdetails.ErrorInfo{
+               Reason: "AnalysisException",
+               Domain: "spark.sql",
+               Metadata: map[string]string{
+                       "sqlState":          "42000",
+                       "errorClass":        "ERROR_CLASS",
+                       "errorId":           "errorId",
+                       "messageParameters": "",
+               },
+       })
+
+       err := status.Err()
+       se := FromRPCError(err)
+       assert.Equal(t, codes.Internal, se.Code)
+       assert.Equal(t, "AnalysisException", se.Message)
+       assert.Equal(t, "AnalysisException", se.Reason)
+       assert.Equal(t, "42000", se.SqlState)
+       assert.Equal(t, "ERROR_CLASS", se.ErrorClass)
+       assert.Equal(t, "errorId", se.ErrorId)
+}
+
+func TestErrorDetailsWithMessageParameterParsing(t *testing.T) {
+       type param struct {
+               TestName string
+               Input    string
+               Expected map[string]string
+       }
+
+       params := []param{
+               {"empty input", "", nil},
+               {"empty input", "{", nil},
+               {"parse error", "{}", map[string]string{}},
+               {"valid input", "{\"key\":\"value\"}", map[string]string{"key": 
"value"}},
+       }
+
+       for _, p := range params {
+               t.Run(p.TestName, func(t *testing.T) {
+                       status := status.New(codes.Internal, 
"AnalysisException")
+                       status, _ = status.WithDetails(&errdetails.ErrorInfo{
+                               Reason: "AnalysisException",
+                               Domain: "spark.sql",
+                               Metadata: map[string]string{
+                                       "sqlState":          "42000",
+                                       "errorClass":        "ERROR_CLASS",
+                                       "errorId":           "errorId",
+                                       "messageParameters": p.Input,
+                               },
+                       })
+
+                       err := status.Err()
+                       se := FromRPCError(err)
+                       assert.Equal(t, codes.Internal, se.Code)
+                       assert.Equal(t, "AnalysisException", se.Message)
+                       assert.Equal(t, "AnalysisException", se.Reason)
+                       assert.Equal(t, "42000", se.SqlState)
+                       assert.Equal(t, "ERROR_CLASS", se.ErrorClass)
+                       assert.Equal(t, "errorId", se.ErrorId)
+                       assert.Equal(t, p.Expected, se.Parameters)
+               })
+       }
+}
+
+func TestSparkError_Error(t *testing.T) {
+       type fields struct {
+               SqlState   string
+               ErrorClass string
+               Reason     string
+               Message    string
+               Code       codes.Code
+               ErrorId    string
+               Parameters map[string]string
+               status     *status.Status
+       }
+       tests := []struct {
+               name   string
+               fields fields
+               want   string
+       }{
+               {
+                       "UNKNOWN",
+                       fields{
+                               Code:    codes.Unknown,
+                               Message: "Unknown error",
+                       },
+                       "[Unknown] Unknown error",
+               },
+               {
+                       "Analysis Exception",
+                       fields{
+                               SqlState:   "42703",
+                               ErrorClass: "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+                               Message:    "A column, variable, or function 
parameter with name `id2` cannot be resolved. Did you mean one of the 
following? [`id`]",
+                               Code:       codes.Internal,
+                       },
+                       "[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column, 
variable, or function parameter with name `id2` cannot be resolved. Did you 
mean one of the following? [`id`]. SQLSTATE: 42703",
+               },
+       }
+       for _, tt := range tests {
+               t.Run(tt.name, func(t *testing.T) {
+                       e := SparkError{
+                               SqlState:   tt.fields.SqlState,
+                               ErrorClass: tt.fields.ErrorClass,
+                               Reason:     tt.fields.Reason,
+                               Message:    tt.fields.Message,
+                               Code:       tt.fields.Code,
+                               ErrorId:    tt.fields.ErrorId,
+                               Parameters: tt.fields.Parameters,
+                               status:     tt.fields.status,
+                       }
+                       assert.Equalf(t, tt.want, e.Error(), "Error()")
+               })
+       }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to