This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 1c98afb fix(go/adbc/driver/flightsql): fix stream timeout interceptor (#490)
1c98afb is described below
commit 1c98afb1a8f7ead4098a6951d4587b16710d239f
Author: Matt Topol <[email protected]>
AuthorDate: Mon Mar 6 16:00:33 2023 -0500
fix(go/adbc/driver/flightsql): fix stream timeout interceptor (#490)
Fixes #482.
---
go/adbc/driver/flightsql/flightsql_adbc.go | 105 +++++++++++++++++++++++-
go/adbc/driver/flightsql/flightsql_adbc_test.go | 77 +++++++++++++++--
go/adbc/driver/flightsql/flightsql_statement.go | 13 ++-
go/adbc/driver/flightsql/utils.go | 2 +
4 files changed, 187 insertions(+), 10 deletions(-)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc.go b/go/adbc/driver/flightsql/flightsql_adbc.go
index c7187ff..2d00216 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc.go
@@ -391,11 +391,112 @@ func unaryTimeoutInterceptor(ctx context.Context, method string, req, reply any,
return invoker(ctx, method, req, reply, cc, opts...)
}
+type streamEventType int
+
+const (
+ receiveEndEvent streamEventType = iota
+ errorEvent
+)
+
+type streamEvent struct {
+ Type streamEventType
+ Err error
+}
+
+type wrappedClientStream struct {
+ grpc.ClientStream
+
+ desc *grpc.StreamDesc
+ events chan streamEvent
+ eventsDone chan struct{}
+}
+
+func (w *wrappedClientStream) RecvMsg(m any) error {
+ err := w.ClientStream.RecvMsg(m)
+
+ switch {
+ case err == nil && !w.desc.ServerStreams:
+ w.sendStreamEvent(receiveEndEvent, nil)
+ case err == io.EOF:
+ w.sendStreamEvent(receiveEndEvent, nil)
+ case err != nil:
+ w.sendStreamEvent(errorEvent, err)
+ }
+
+ return err
+}
+
+func (w *wrappedClientStream) SendMsg(m any) error {
+ err := w.ClientStream.SendMsg(m)
+ if err != nil {
+ w.sendStreamEvent(errorEvent, err)
+ }
+ return err
+}
+
+func (w *wrappedClientStream) Header() (metadata.MD, error) {
+ md, err := w.ClientStream.Header()
+ if err != nil {
+ w.sendStreamEvent(errorEvent, err)
+ }
+ return md, err
+}
+
+func (w *wrappedClientStream) CloseSend() error {
+ err := w.ClientStream.CloseSend()
+ if err != nil {
+ w.sendStreamEvent(errorEvent, err)
+ }
+ return err
+}
+
+func (w *wrappedClientStream) sendStreamEvent(eventType streamEventType, err error) {
+ select {
+ case <-w.eventsDone:
+ case w.events <- streamEvent{Type: eventType, Err: err}:
+ }
+}
+
func streamTimeoutInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
if tm, ok := getTimeout(method, opts); ok {
ctx, cancel := context.WithTimeout(ctx, tm)
- defer cancel()
- return streamer(ctx, desc, cc, method, opts...)
+ s, err := streamer(ctx, desc, cc, method, opts...)
+ if err != nil {
+ defer cancel()
+ return s, err
+ }
+
+ events, eventsDone := make(chan streamEvent), make(chan struct{})
+ go func() {
+ defer close(eventsDone)
+ defer cancel()
+
+ for {
+ select {
+ case event := <-events:
+ // split by event type in case we want to add more logging
+ // or even adding in some telemetry in the future.
+ // Errors will already be propagated by the RecvMsg, SendMsg
+ // methods.
+ switch event.Type {
+ case receiveEndEvent:
+ return
+ case errorEvent:
+ return
+ }
+ case <-ctx.Done():
+ return
+ }
+ }
+ }()
+
+ stream := &wrappedClientStream{
+ ClientStream: s,
+ desc: desc,
+ events: events,
+ eventsDone: eventsDone,
+ }
+ return stream, nil
}
return streamer(ctx, desc, cc, method, opts...)
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_test.go b/go/adbc/driver/flightsql/flightsql_adbc_test.go
index 44b4fd7..045b542 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_test.go
@@ -674,7 +674,27 @@ type TimeoutTestServer struct {
flightsql.BaseServer
}
-func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, _ flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+func (ts *TimeoutTestServer) DoGetStatement(ctx context.Context, tkt flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
+ if string(tkt.GetStatementHandle()) == "sleep and succeed" {
+ time.Sleep(1 * time.Second)
+ sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+ rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`))
+ if err != nil {
+ return nil, nil, err
+ }
+
+ ch := make(chan flight.StreamChunk)
+ go func() {
+ defer close(ch)
+ ch <- flight.StreamChunk{
+ Data: rec,
+ Desc: nil,
+ Err: nil,
+ }
+ }()
+ return sc, ch, nil
+ }
+
// wait till the context is cancelled
<-ctx.Done()
return nil, nil, arrow.ErrNotImplemented
@@ -702,10 +722,27 @@ func (ts *TimeoutTestServer) GetFlightInfoStatement(ctx context.Context, cmd fli
TotalBytes: -1,
}
return info, nil
+ case "notimeout":
+ time.Sleep(1 * time.Second)
+ tkt, _ := flightsql.CreateStatementQueryTicket([]byte("sleep and succeed"))
+ info := &flight.FlightInfo{
+ FlightDescriptor: desc,
+ Endpoint: []*flight.FlightEndpoint{
+ {Ticket: &flight.Ticket{Ticket: tkt}},
+ },
+ TotalRecords: -1,
+ TotalBytes: -1,
+ }
+ return info, nil
}
return nil, arrow.ErrNotImplemented
}
+func (ts *TimeoutTestServer) CreatePreparedStatement(ctx context.Context, req flightsql.ActionCreatePreparedStatementRequest) (result flightsql.ActionCreatePreparedStatementResult, err error) {
+ <-ctx.Done()
+ return result, arrow.ErrNotImplemented
+}
+
type TimeoutTestSuite struct {
suite.Suite
@@ -792,7 +829,7 @@ func (ts *TimeoutTestSuite) TestDoActionTimeout() {
ts.Require().NoError(stmt.SetSqlQuery("fetch"))
var adbcErr adbc.Error
ts.ErrorAs(stmt.Prepare(context.Background()), &adbcErr)
- ts.Equal(adbc.StatusCancelled, adbcErr.Code)
+ ts.Equal(adbc.StatusTimeout, adbcErr.Code)
}
func (ts *TimeoutTestSuite) TestDoGetTimeout() {
@@ -807,20 +844,22 @@ func (ts *TimeoutTestSuite) TestDoGetTimeout() {
var adbcErr adbc.Error
_, _, err = stmt.ExecuteQuery(context.Background())
ts.ErrorAs(err, &adbcErr)
- ts.Equal(adbc.StatusCancelled, adbcErr.Code)
+ ts.Equal(adbc.StatusTimeout, adbcErr.Code)
}
func (ts *TimeoutTestSuite) TestDoPutTimeout() {
ts.NoError(ts.cnxn.(adbc.PostInitOptions).
- SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "0.1"))
+ SetOption("adbc.flight.sql.rpc.timeout_seconds.update", "5.1"))
stmt, err := ts.cnxn.NewStatement()
ts.Require().NoError(err)
defer stmt.Close()
ts.Require().NoError(stmt.SetSqlQuery("timeout"))
+ var adbcErr adbc.Error
_, err = stmt.ExecuteUpdate(context.Background())
- ts.Error(err)
+ ts.ErrorAs(err, &adbcErr)
+ ts.Equal(adbc.StatusTimeout, adbcErr.Code)
}
func (ts *TimeoutTestSuite) TestGetFlightInfoTimeout() {
@@ -838,6 +877,34 @@ func (ts *TimeoutTestSuite) TestGetFlightInfoTimeout() {
ts.NotEqual(adbc.StatusNotImplemented, adbcErr.Code)
}
+func (ts *TimeoutTestSuite) TestDontTimeout() {
+ ts.NoError(ts.cnxn.(adbc.PostInitOptions).
+ SetOption("adbc.flight.sql.rpc.timeout_seconds.fetch", "2.0"))
+ ts.NoError(ts.cnxn.(adbc.PostInitOptions).
+ SetOption("adbc.flight.sql.rpc.timeout_seconds.query", "2.0"))
+
+ stmt, err := ts.cnxn.NewStatement()
+ ts.Require().NoError(err)
+ defer stmt.Close()
+
+ ts.Require().NoError(stmt.SetSqlQuery("notimeout"))
+ // GetFlightInfo will sleep for one second and DoGet will also
+ // sleep for one second. But our timeout is 2 seconds, which is
+ // per-operation. So we shouldn't time out and all should succeed.
+ rr, _, err := stmt.ExecuteQuery(context.Background())
+ ts.Require().NoError(err)
+ defer rr.Release()
+
+ ts.True(rr.Next())
+ rec := rr.Record()
+
+ sc := arrow.NewSchema([]arrow.Field{{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}}, nil)
+ expected, _, err := array.RecordFromJSON(memory.DefaultAllocator, sc, strings.NewReader(`[{"a": 5}]`))
+ ts.Require().NoError(err)
+ defer expected.Release()
+ ts.Truef(array.RecordEqual(rec, expected), "expected: %s\nactual: %s", expected, rec)
+}
+
type TLSTests struct {
suite.Suite
diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go
index 980ade5..1429d07 100644
--- a/go/adbc/driver/flightsql/flightsql_statement.go
+++ b/go/adbc/driver/flightsql/flightsql_statement.go
@@ -249,13 +249,20 @@ func (s *statement) ExecuteQuery(ctx context.Context) (rdr array.RecordReader, n
// ExecuteUpdate executes a statement that does not generate a result
// set. It returns the number of rows affected if known, otherwise -1.
-func (s *statement) ExecuteUpdate(ctx context.Context) (int64, error) {
+func (s *statement) ExecuteUpdate(ctx context.Context) (n int64, err error) {
ctx = metadata.NewOutgoingContext(ctx, s.hdrs)
+
if s.prepared != nil {
- return s.prepared.ExecuteUpdate(ctx, s.timeouts)
+ n, err = s.prepared.ExecuteUpdate(ctx, s.timeouts)
+ } else {
+ n, err = s.query.executeUpdate(ctx, s.cnxn, s.timeouts)
+ }
+
+ if err != nil {
+ err = adbcFromFlightStatus(err)
}
- return s.query.executeUpdate(ctx, s.cnxn, s.timeouts)
+ return
}
// Prepare turns this statement into a prepared statement to be executed
diff --git a/go/adbc/driver/flightsql/utils.go b/go/adbc/driver/flightsql/utils.go
index e3dd654..6ffd253 100644
--- a/go/adbc/driver/flightsql/utils.go
+++ b/go/adbc/driver/flightsql/utils.go
@@ -52,6 +52,8 @@ func adbcFromFlightStatus(err error) error {
adbcCode = adbc.StatusNotImplemented
case codes.PermissionDenied:
adbcCode = adbc.StatusUnauthorized
+ case codes.DeadlineExceeded:
+ adbcCode = adbc.StatusTimeout
default:
adbcCode = adbc.StatusUnknown
}