This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 6263cfd5 feat(go/adbc/driver/flightsql): add context to gRPC errors (#921)
6263cfd5 is described below

commit 6263cfd5c99bf53c9c3672bacf03b6d604a824c4
Author: David Li <[email protected]>
AuthorDate: Thu Jul 20 18:34:17 2023 -0400

    feat(go/adbc/driver/flightsql): add context to gRPC errors (#921)
    
    See #862.
    
    Example:
    
    ```
    Internal: SqlState: , msg: [FlightSQL] Ballista Error: General("scheduler::from_proto(Action) invalid or missing action") (Internal; DoGet: endpoint 0: [uri:"grpc+tcp://172.24.0.5:50051"])
    ```
    
    Now we can see that the error comes from issuing a DoGet against a
    particular location.
---
 go/adbc/driver/flightsql/flightsql_adbc.go         | 48 +++++++++++-----------
 .../driver/flightsql/flightsql_adbc_server_test.go |  2 +-
 go/adbc/driver/flightsql/flightsql_statement.go    | 10 ++---
 go/adbc/driver/flightsql/record_reader.go          |  4 +-
 go/adbc/driver/flightsql/utils.go                  |  7 +++-
 5 files changed, 37 insertions(+), 34 deletions(-)

diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go 
b/go/adbc/driver/flightsql/flightsql_adbc.go
index e038354c..1ae99a6a 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -892,10 +892,10 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes 
[]adbc.InfoCode) (array.Re
        ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
        info, err := c.cl.GetSqlInfo(ctx, translated, c.timeouts)
        if err == nil {
-               for _, endpoint := range info.Endpoint {
+               for i, endpoint := range info.Endpoint {
                        rdr, err := doGet(ctx, c.cl, endpoint, c.clientCache, 
c.timeouts)
                        if err != nil {
-                               return nil, adbcFromFlightStatus(err)
+                               return nil, adbcFromFlightStatus(err, 
"GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location)
                        }
 
                        for rdr.Next() {
@@ -922,11 +922,11 @@ func (c *cnxn) GetInfo(ctx context.Context, infoCodes 
[]adbc.InfoCode) (array.Re
                        }
 
                        if rdr.Err() != nil {
-                               return nil, adbcFromFlightStatus(rdr.Err())
+                               return nil, adbcFromFlightStatus(rdr.Err(), 
"GetInfo(DoGet): endpoint %d: %s", i, endpoint.Location)
                        }
                }
        } else if grpcstatus.Code(err) != grpccodes.Unimplemented {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, "GetInfo(GetSqlInfo)")
        }
 
        final := bldr.NewRecord()
@@ -1032,12 +1032,12 @@ func (c *cnxn) GetObjects(ctx context.Context, depth 
adbc.ObjectDepth, catalog *
        // To avoid an N+1 query problem, we assume result sets here will fit 
in memory and build up a single response.
        info, err := c.cl.GetCatalogs(ctx)
        if err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)")
        }
 
        rdr, err := c.readInfo(ctx, schema_ref.Catalogs, info)
        if err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)")
        }
        defer rdr.Release()
 
@@ -1058,7 +1058,7 @@ func (c *cnxn) GetObjects(ctx context.Context, depth 
adbc.ObjectDepth, catalog *
        }
 
        if err = rdr.Err(); err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, "GetObjects(GetCatalogs)")
        }
 
        return g.Finish()
@@ -1069,7 +1069,7 @@ func (c *cnxn) readInfo(ctx context.Context, 
expectedSchema *arrow.Schema, info
        // use a default queueSize for the reader
        rdr, err := newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 
5)
        if err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, "DoGet")
        }
 
        if !rdr.Schema().Equal(expectedSchema) {
@@ -1091,12 +1091,12 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, 
depth adbc.ObjectDepth,
        // Pre-populate the map of which schemas are in which catalogs
        info, err := c.cl.GetDBSchemas(ctx, 
&flightsql.GetDBSchemasOpts{DbSchemaFilterPattern: dbSchema})
        if err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, 
"GetObjects(GetDBSchemas)")
        }
 
        rdr, err := c.readInfo(ctx, schema_ref.DBSchemas, info)
        if err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, 
"GetObjects(GetDBSchemas)")
        }
        defer rdr.Release()
 
@@ -1117,7 +1117,7 @@ func (c *cnxn) getObjectsDbSchemas(ctx context.Context, 
depth adbc.ObjectDepth,
 
        if rdr.Err() != nil {
                result = nil
-               err = adbcFromFlightStatus(rdr.Err())
+               err = adbcFromFlightStatus(rdr.Err(), 
"GetObjects(GetDBSchemas)")
        }
        return
 }
@@ -1137,7 +1137,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, 
depth adbc.ObjectDepth, cat
                IncludeSchema:          includeSchema,
        })
        if err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)")
        }
 
        expectedSchema := schema_ref.Tables
@@ -1146,7 +1146,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, 
depth adbc.ObjectDepth, cat
        }
        rdr, err := c.readInfo(ctx, expectedSchema, info)
        if err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, "GetObjects(GetTables)")
        }
        defer rdr.Release()
 
@@ -1195,7 +1195,7 @@ func (c *cnxn) getObjectsTables(ctx context.Context, 
depth adbc.ObjectDepth, cat
 
        if rdr.Err() != nil {
                result = nil
-               err = adbcFromFlightStatus(rdr.Err())
+               err = adbcFromFlightStatus(rdr.Err(), "GetObjects(GetTables)")
        }
        return
 }
@@ -1211,12 +1211,12 @@ func (c *cnxn) GetTableSchema(ctx context.Context, 
catalog *string, dbSchema *st
        ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
        info, err := c.cl.GetTables(ctx, opts, c.timeouts)
        if err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, 
"GetTableSchema(GetTables)")
        }
 
        rdr, err := doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, 
c.timeouts)
        if err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)")
        }
        defer rdr.Release()
 
@@ -1228,7 +1228,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, 
catalog *string, dbSchema *st
                                Code: adbc.StatusNotFound,
                        }
                }
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, "GetTableSchema(DoGet)")
        }
 
        if rec.NumRows() == 0 {
@@ -1246,7 +1246,7 @@ func (c *cnxn) GetTableSchema(ctx context.Context, 
catalog *string, dbSchema *st
        schemaBytes := rec.Column(4).(*array.Binary).Value(0)
        s, err := flight.DeserializeSchema(schemaBytes, c.db.alloc)
        if err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, "GetTableSchema")
        }
        return s, nil
 }
@@ -1262,7 +1262,7 @@ func (c *cnxn) GetTableTypes(ctx context.Context) 
(array.RecordReader, error) {
        ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
        info, err := c.cl.GetTableTypes(ctx, c.timeouts)
        if err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, "GetTableTypes")
        }
 
        return newRecordReader(ctx, c.db.alloc, c.cl, info, c.clientCache, 5)
@@ -1289,12 +1289,12 @@ func (c *cnxn) Commit(ctx context.Context) error {
        ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
        err := c.txn.Commit(ctx, c.timeouts)
        if err != nil {
-               return adbcFromFlightStatus(err)
+               return adbcFromFlightStatus(err, "Commit")
        }
 
        c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts)
        if err != nil {
-               return adbcFromFlightStatus(err)
+               return adbcFromFlightStatus(err, "BeginTransaction")
        }
        return nil
 }
@@ -1320,12 +1320,12 @@ func (c *cnxn) Rollback(ctx context.Context) error {
        ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
        err := c.txn.Rollback(ctx, c.timeouts)
        if err != nil {
-               return adbcFromFlightStatus(err)
+               return adbcFromFlightStatus(err, "Rollback")
        }
 
        c.txn, err = c.cl.BeginTransaction(ctx, c.timeouts)
        if err != nil {
-               return adbcFromFlightStatus(err)
+               return adbcFromFlightStatus(err, "BeginTransaction")
        }
        return nil
 }
@@ -1428,7 +1428,7 @@ func (c *cnxn) ReadPartition(ctx context.Context, 
serializedPartition []byte) (r
        ctx = metadata.NewOutgoingContext(ctx, c.hdrs)
        rdr, err = doGet(ctx, c.cl, info.Endpoint[0], c.clientCache, c.timeouts)
        if err != nil {
-               return nil, adbcFromFlightStatus(err)
+               return nil, adbcFromFlightStatus(err, "ReadPartition(DoGet)")
        }
        return rdr, nil
 }
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go 
b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 61a46db1..dd6171c4 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -338,7 +338,7 @@ func (ts *TimeoutTests) TestDoActionTimeout() {
        ts.ErrorAs(stmt.Prepare(context.Background()), &adbcErr)
        ts.Equal(adbc.StatusTimeout, adbcErr.Code, adbcErr.Error())
        // Exact match - we don't want extra fluff in the message
-       ts.Equal("context deadline exceeded", adbcErr.Msg)
+       ts.Equal("[FlightSQL] context deadline exceeded (DeadlineExceeded; 
Prepare)", adbcErr.Msg)
 }
 
 func (ts *TimeoutTests) TestDoGetTimeout() {
diff --git a/go/adbc/driver/flightsql/flightsql_statement.go 
b/go/adbc/driver/flightsql/flightsql_statement.go
index c7f074a8..3e7d20e1 100644
--- a/go/adbc/driver/flightsql/flightsql_statement.go
+++ b/go/adbc/driver/flightsql/flightsql_statement.go
@@ -239,7 +239,7 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr 
array.RecordReader, n
        }
 
        if err != nil {
-               return nil, -1, adbcFromFlightStatus(err)
+               return nil, -1, adbcFromFlightStatus(err, "ExecuteQuery")
        }
 
        nrec = info.TotalRecords
@@ -259,7 +259,7 @@ func (s *statement) ExecuteUpdate(ctx context.Context) (n 
int64, err error) {
        }
 
        if err != nil {
-               err = adbcFromFlightStatus(err)
+               err = adbcFromFlightStatus(err, "ExecuteUpdate")
        }
 
        return
@@ -271,7 +271,7 @@ func (s *statement) Prepare(ctx context.Context) error {
        ctx = metadata.NewOutgoingContext(ctx, s.hdrs)
        prep, err := s.query.prepare(ctx, s.cnxn, s.timeouts)
        if err != nil {
-               return adbcFromFlightStatus(err)
+               return adbcFromFlightStatus(err, "Prepare")
        }
        s.prepared = prep
        return nil
@@ -394,13 +394,13 @@ func (s *statement) ExecutePartitions(ctx 
context.Context) (*arrow.Schema, adbc.
        }
 
        if err != nil {
-               return nil, out, -1, adbcFromFlightStatus(err)
+               return nil, out, -1, adbcFromFlightStatus(err, 
"ExecutePartitions")
        }
 
        if len(info.Schema) > 0 {
                sc, err = flight.DeserializeSchema(info.Schema, s.alloc)
                if err != nil {
-                       return nil, out, -1, adbcFromFlightStatus(err)
+                       return nil, out, -1, adbcFromFlightStatus(err, 
"ExecutePartitions: could not deserialize FlightInfo schema:")
                }
        }
 
diff --git a/go/adbc/driver/flightsql/record_reader.go 
b/go/adbc/driver/flightsql/record_reader.go
index 409ce58e..c2721a7a 100644
--- a/go/adbc/driver/flightsql/record_reader.go
+++ b/go/adbc/driver/flightsql/record_reader.go
@@ -90,7 +90,7 @@ func newRecordReader(ctx context.Context, alloc 
memory.Allocator, cl *flightsql.
        } else {
                rdr, err := doGet(ctx, cl, endpoints[0], clCache, opts...)
                if err != nil {
-                       return nil, adbcFromFlightStatus(err)
+                       return nil, adbcFromFlightStatus(err, "DoGet: endpoint 
0: remote: %s", endpoints[0].Location)
                }
                schema = rdr.Schema()
                group.Go(func() error {
@@ -135,7 +135,7 @@ func newRecordReader(ctx context.Context, alloc 
memory.Allocator, cl *flightsql.
 
                        rdr, err := doGet(ctx, cl, endpoint, clCache, opts...)
                        if err != nil {
-                               return err
+                               return adbcFromFlightStatus(err, "DoGet: 
endpoint %d: %s", endpointIndex, endpoint.Location)
                        }
                        defer rdr.Release()
 
diff --git a/go/adbc/driver/flightsql/utils.go 
b/go/adbc/driver/flightsql/utils.go
index cbf9048f..e4cf2768 100644
--- a/go/adbc/driver/flightsql/utils.go
+++ b/go/adbc/driver/flightsql/utils.go
@@ -18,12 +18,14 @@
 package flightsql
 
 import (
+       "fmt"
+
        "github.com/apache/arrow-adbc/go/adbc"
        "google.golang.org/grpc/codes"
        "google.golang.org/grpc/status"
 )
 
-func adbcFromFlightStatus(err error) error {
+func adbcFromFlightStatus(err error, context string, args ...any) error {
        if _, ok := err.(adbc.Error); ok {
                return err
        }
@@ -70,8 +72,9 @@ func adbcFromFlightStatus(err error) error {
                adbcCode = adbc.StatusUnknown
        }
 
+       // People don't read error messages, so backload the context and 
frontload the server error
        return adbc.Error{
-               Msg:  grpcStatus.Message(),
+               Msg:  fmt.Sprintf("[FlightSQL] %s (%s; %s)", 
grpcStatus.Message(), grpcStatus.Code(), fmt.Sprintf(context, args...)),
                Code: adbcCode,
        }
 }

Reply via email to