This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c98afb  fix(go/adbc/driver/flightsql): fix stream timeout interceptor (#490)
1c98afb is described below

commit 1c98afb1a8f7ead4098a6951d4587b16710d239f
Author: Matt Topol <[email protected]>
AuthorDate: Mon Mar 6 16:00:33 2023 -0500

    fix(go/adbc/driver/flightsql): fix stream timeout interceptor (#490)
    
    Fixes #482.
---
 go/adbc/driver/flightsql/flightsql_adbc.go      | 105 +++++++++++++++++++++++-
 go/adbc/driver/flightsql/flightsql_adbc_test.go |  77 +++++++++++++++--
 go/adbc/driver/flightsql/flightsql_statement.go |  13 ++-
 go/adbc/driver/flightsql/utils.go               |   2 +
 4 files changed, 187 insertions(+), 10 deletions(-)

diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go
index c7187ff..2d00216 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -391,11 +391,112 @@ func unaryTimeoutInterceptor(ctx context.Context, method string, req, reply any,
        return invoker(ctx, method, req, reply, cc, opts...)
 }
 
+// streamEventType classifies the notifications a wrappedClientStream sends
+// to the goroutine that owns the stream's timeout context.
+type streamEventType int
+
+const (
+       // receiveEndEvent: receiving finished normally (io.EOF, or the single
+       // response of an RPC whose server side does not stream was read).
+       receiveEndEvent streamEventType = iota
+       // errorEvent: a call on the stream returned a non-nil error.
+       errorEvent
+)
+
+type (
+       // streamEvent is one notification sent over wrappedClientStream.events.
+       streamEvent struct {
+               Type streamEventType
+               Err  error
+       }
+
+       // wrappedClientStream decorates a grpc.ClientStream so that the end of
+       // the stream, or any error on it, is reported on events. eventsDone is
+       // closed once nobody is listening on events any more.
+       wrappedClientStream struct {
+               grpc.ClientStream
+
+               desc       *grpc.StreamDesc
+               events     chan streamEvent
+               eventsDone chan struct{}
+       }
+)
+
+// RecvMsg forwards to the wrapped stream and reports the outcome. io.EOF, or
+// a successful read on an RPC whose server side does not stream (that read was
+// the only response), is reported as receiveEndEvent. Any other non-nil error
+// is reported as errorEvent. The error from the wrapped stream is returned
+// unchanged.
+func (w *wrappedClientStream) RecvMsg(m any) error {
+       recvErr := w.ClientStream.RecvMsg(m)
+
+       if recvErr == io.EOF || (recvErr == nil && !w.desc.ServerStreams) {
+               w.sendStreamEvent(receiveEndEvent, nil)
+       } else if recvErr != nil {
+               w.sendStreamEvent(errorEvent, recvErr)
+       }
+
+       return recvErr
+}
+
+// SendMsg forwards m to the wrapped stream. A send failure is reported as an
+// errorEvent before being returned to the caller.
+func (w *wrappedClientStream) SendMsg(m any) error {
+       if sendErr := w.ClientStream.SendMsg(m); sendErr != nil {
+               w.sendStreamEvent(errorEvent, sendErr)
+               return sendErr
+       }
+       return nil
+}
+
+// Header returns the wrapped stream's header metadata. If that call fails, the
+// error is also reported as an errorEvent.
+func (w *wrappedClientStream) Header() (metadata.MD, error) {
+       hdr, hdrErr := w.ClientStream.Header()
+       if hdrErr == nil {
+               return hdr, nil
+       }
+       w.sendStreamEvent(errorEvent, hdrErr)
+       return hdr, hdrErr
+}
+
+// CloseSend forwards to the wrapped stream's CloseSend. A non-nil error is
+// also reported as an errorEvent.
+func (w *wrappedClientStream) CloseSend() error {
+       closeErr := w.ClientStream.CloseSend()
+       if closeErr == nil {
+               return nil
+       }
+       w.sendStreamEvent(errorEvent, closeErr)
+       return closeErr
+}
+
+// sendStreamEvent delivers one event to the goroutine started by
+// streamTimeoutInterceptor. events is unbuffered. Once that goroutine has
+// exited (eventsDone is closed), the event is dropped instead of blocking
+// forever.
+func (w *wrappedClientStream) sendStreamEvent(eventType streamEventType, err error) {
+       ev := streamEvent{Type: eventType, Err: err}
+       select {
+       case w.events <- ev:
+       case <-w.eventsDone:
+       }
+}
+
 func streamTimeoutInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc 
*grpc.ClientConn, method string, streamer grpc.Streamer, opts 
...grpc.CallOption) (grpc.ClientStream, error) {
        if tm, ok := getTimeout(method, opts); ok {
                ctx, cancel := context.WithTimeout(ctx, tm)
-               defer cancel()
-               return streamer(ctx, desc, cc, method, opts...)
+               s, err := streamer(ctx, desc, cc, method, opts...)
+               if err != nil {
+                       defer cancel()
+                       return s, err
+               }
+
+               events, eventsDone := make(chan streamEvent), make(chan 
struct{})
+               go func() {
+                       defer close(eventsDone)
+                       defer cancel()
+
+                       for {
+                               select {
+                               case event := <-events:
+                                       // split by event type in case we want 
to add more logging
+                                       // or even adding in some telemetry in 
the future.
+                                       // Errors will already be propagated by 
the RecvMsg, SendMsg
+                                       // methods.
+                                       switch event.Type {
+                                       case receiveEndEvent:
+                                               return
+                                       case errorEvent:
+                                               return
+                                       }
+                               case <-ctx.Done():
+                                       return
+                               }
+                       }
+               }()
+
+               stream := &wrappedClientStream{
+                       ClientStream: s,
+                       desc:         desc,
+                       events:       events,
+                       eventsDone:   eventsDone,
+               }
+               return stream, nil
        }
 
        return streamer(ctx, desc, cc, method, opts...)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 44b4fd7..045b542 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -674,7 +674,27 @@ type TimeoutTestServer struct {
        flightsql.BaseServer
 }
 
-func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, _ 
flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, 
error) {
+func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, tkt 
flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, 
error) {
+       if string(tkt.GetStatementHandle()) == "sleep and succeed" {
+               time.Sleep(1 * time.Second)
+               sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: 
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+               rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, 
sc, strings.NewReader(`[{"a": 5}]`))
+               if err != nil {
+                       return nil, nil, err
+               }
+
+               ch := make(chan flight.StreamChunk)
+               go func() {
+                       defer close(ch)
+                       ch <- flight.StreamChunk{
+                               Data: rec,
+                               Desc: nil,
+                               Err:  nil,
+                       }
+               }()
+               return sc, ch, nil
+       }
+
        // wait till the context is cancelled
        <-ctx.Done()
        return nil, nil, arrow.ErrNotImplemented
@@ -702,10 +722,27 @@ func (ts *TimeoutTestServer) GetFlightInfoStatement(ctx context.Context, cmd fli
                        TotalBytes:   -1,
                }
                return info, nil
+       case "notimeout":
+               time.Sleep(1 * time.Second)
+               tkt, _ := flightsql.CreateStatementQueryTicket([]byte("sleep 
and succeed"))
+               info := &flight.FlightInfo{
+                       FlightDescriptor: desc,
+                       Endpoint: []*flight.FlightEndpoint{
+                               {Ticket: &flight.Ticket{Ticket: tkt}},
+                       },
+                       TotalRecords: -1,
+                       TotalBytes:   -1,
+               }
+               return info, nil
        }
        return nil, arrow.ErrNotImplemented
 }
 
+// CreatePreparedStatement never succeeds. It waits until the incoming RPC
+// context is done (deadline or cancellation) and then returns
+// arrow.ErrNotImplemented, so the call can only end from the client side.
+func (ts *TimeoutTestServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) {
+       done := ctx.Done()
+       <-done
+       err = arrow.ErrNotImplemented
+       return result, err
+}
+
 type TimeoutTestSuite struct {
        suite.Suite
 
@@ -792,7 +829,7 @@ func (ts *TimeoutTestSuite) TestDoActionTimeout() {
        ts.Require().NoError(stmt.SetSqlQuery("fetch"))
        var adbcErr adbc.Error
        ts.ErrorAs(stmt.Prepare(context.Background()), &adbcErr)
-       ts.Equal(adbc.StatusCancelled, adbcErr.Code)
+       ts.Equal(adbc.StatusTimeout, adbcErr.Code)
 }
 
 func (ts *TimeoutTestSuite) TestDoGetTimeout() {
@@ -807,20 +844,22 @@ func (ts *TimeoutTestSuite) TestDoGetTimeout() {
        var adbcErr adbc.Error
        _, _, err = stmt.ExecuteQuery(context.Background())
        ts.ErrorAs(err, &adbcErr)
-       ts.Equal(adbc.StatusCancelled, adbcErr.Code)
+       ts.Equal(adbc.StatusTimeout, adbcErr.Code)
 }
 
 func (ts *TimeoutTestSuite) TestDoPutTimeout() {
        ts.NoError(ts.cnxn.(adbc.PostInitOptions).
-               SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "0.1"))
+               SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "5.1"))
 
        stmt, err := ts.cnxn.NewStatement()
        ts.Require().NoError(err)
        defer stmt.Close()
 
        ts.Require().NoError(stmt.SetSqlQuery("timeout"))
+       var adbcErr adbc.Error
        _, err = stmt.ExecuteUpdate(context.Background())
-       ts.Error(err)
+       ts.ErrorAs(err, &adbcErr)
+       ts.Equal(adbc.StatusTimeout, adbcErr.Code)
 }
 
 func (ts *TimeoutTestSuite) TestGetFlightInfoTimeout() {
@@ -838,6 +877,34 @@ func (ts *TimeoutTestSuite) TestGetFlightInfoTimeout() {
        ts.NotEqual(adbc.StatusNotImplemented, adbcErr.Code)
 }
 
+// TestDontTimeout checks that timeouts apply per operation: GetFlightInfo and
+// DoGet each take about one second on the server, both under the configured
+// two-second limit, so the query must succeed and return the expected record.
+func (ts *TimeoutTestSuite) TestDontTimeout() {
+       ts.NoError(ts.cnxn.(adbc.PostInitOptions).
+               SetOption("adbc.flight.sql.rpc.timeout_seconds.fetch", "2.0"))
+       ts.NoError(ts.cnxn.(adbc.PostInitOptions).
+               SetOption("adbc.flight.sql.rpc.timeout_seconds.query", "2.0"))
+
+       stmt, err := ts.cnxn.NewStatement()
+       ts.Require().NoError(err)
+       defer stmt.Close()
+
+       ts.Require().NoError(stmt.SetSqlQuery("notimeout"))
+       // GetFlightInfo will sleep for one second and DoGet will also
+       // sleep for one second. But our timeout is 2 seconds, which is
+       // per-operation. So we shouldn't time out and all should succeed.
+       rr, _, err := stmt.ExecuteQuery(context.Background())
+       ts.Require().NoError(err)
+       defer rr.Release()
+
+       // Require (not just assert) a record. If Next() returned false,
+       // rr.Record() would presumably be nil and RecordEqual below would
+       // panic instead of reporting a clean test failure.
+       ts.Require().True(rr.Next())
+       rec := rr.Record()
+
+       sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+       expected, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`))
+       ts.Require().NoError(err)
+       defer expected.Release()
+       ts.Truef(array.RecordEqual(rec, expected), "expected: %s\nactual: %s", expected, rec)
+
+       // The test server sends exactly one chunk, so the stream should now
+       // be exhausted.
+       ts.False(rr.Next())
+}
+
 type TLSTests struct {
        suite.Suite
 
diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go
index 980ade5..1429d07 100644
--- a/go/adbc/driver/flightsql/flightsql_statement.go
+++ b/go/adbc/driver/flightsql/flightsql_statement.go
@@ -249,13 +249,20 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n
 
 // ExecuteUpdate executes a statement that does not generate a result
 // set. It returns the number of rows affected if known, otherwise -1.
-func (s *statement) ExecuteUpdate(ctx context.Context) (int64, error) {
+func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) {
        ctx = metadata.NewOutgoingContext(ctx, s.hdrs)
+
        if s.prepared != nil {
-               return s.prepared.ExecuteUpdate(ctx, s.timeouts)
+               n, err = s.prepared.ExecuteUpdate(ctx, s.timeouts)
+       } else {
+               n, err = s.query.executeUpdate(ctx, s.cnxn, s.timeouts)
+       }
+
+       if err != nil {
+               err = adbcFromFlightStatus(err)
        }
 
-       return s.query.executeUpdate(ctx, s.cnxn, s.timeouts)
+       return
 }
 
 // Prepare turns this statement into a prepared statement to be executed
diff --git a/go/adbc/driver/flightsql/utils.go b/go/adbc/driver/flightsql/utils.go
index e3dd654..6ffd253 100644
--- a/go/adbc/driver/flightsql/utils.go
+++ b/go/adbc/driver/flightsql/utils.go
@@ -52,6 +52,8 @@ func adbcFromFlightStatus(err error) error {
                adbcCode = adbc.StatusNotImplemented
        case codes.PermissionDenied:
                adbcCode = adbc.StatusUnauthorized
+       case codes.DeadlineExceeded:
+               adbcCode = adbc.StatusTimeout
        default:
                adbcCode = adbc.StatusUnknown
        }

Reply via email to