This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new d9629615 feat(go/adbc/driver/flightsql): expose FlightInfo during polling (#1582)
d9629615 is described below

commit d9629615155b5bfebf288fd56c1f6272ea22c3bb
Author: David Li <[email protected]>
AuthorDate: Fri Mar 1 18:20:59 2024 -0500

    feat(go/adbc/driver/flightsql): expose FlightInfo during polling (#1582)
    
    Fixes #1571.
---
 go/adbc/driver/flightsql/cmd/testserver/main.go    | 25 +++++-
 .../driver/flightsql/flightsql_adbc_server_test.go | 94 ++++++++++++++++++++++
 go/adbc/driver/flightsql/flightsql_driver.go       |  1 +
 go/adbc/driver/flightsql/flightsql_statement.go    | 21 +++++
 .../adbc_driver_flightsql/__init__.py              |  9 +++
 .../tests/test_incremental.py                      | 51 ++++++++++++
 6 files changed, 197 insertions(+), 4 deletions(-)

diff --git a/go/adbc/driver/flightsql/cmd/testserver/main.go b/go/adbc/driver/flightsql/cmd/testserver/main.go
index 22a928ad..8ce65c9f 100644
--- a/go/adbc/driver/flightsql/cmd/testserver/main.go
+++ b/go/adbc/driver/flightsql/cmd/testserver/main.go
@@ -134,15 +134,32 @@ func (srv *ExampleServer) PollFlightInfo(ctx context.Context, desc *flight.Fligh
                return nil, err
        }
 
-       srv.pollingStatus[val.Value]--
-       progress := srv.pollingStatus[val.Value]
-
        ticket, err := flightsql.CreateStatementQueryTicket([]byte(val.Value))
        if err != nil {
                return nil, err
        }
 
-       endpoints := make([]*flight.FlightEndpoint, 5-progress)
+       if val.Value == "forever" {
+               srv.pollingStatus[val.Value]++
+               return &flight.PollInfo{
+                       Info: &flight.FlightInfo{
+                               Schema:           
flight.SerializeSchema(arrow.NewSchema([]arrow.Field{{Name: "ints", Type: 
arrow.PrimitiveTypes.Int32, Nullable: true}}, nil), srv.Alloc),
+                               Endpoint:         
[]*flight.FlightEndpoint{{Ticket: &flight.Ticket{Ticket: ticket}}},
+                               FlightDescriptor: desc,
+                               TotalRecords:     -1,
+                               TotalBytes:       -1,
+                               AppMetadata:      []byte("app metadata"),
+                       },
+                       FlightDescriptor: desc,
+                       Progress:         
proto.Float64(float64(srv.pollingStatus[val.Value]) / 100.0),
+               }, nil
+       }
+
+       srv.pollingStatus[val.Value]--
+       progress := srv.pollingStatus[val.Value]
+
+       numEndpoints := 5 - progress
+       endpoints := make([]*flight.FlightEndpoint, numEndpoints)
        for i := range endpoints {
                endpoints[i] = &flight.FlightEndpoint{Ticket: 
&flight.Ticket{Ticket: ticket}}
        }
diff --git a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
index 66e94da4..78ae01b4 100644
--- a/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
+++ b/go/adbc/driver/flightsql/flightsql_adbc_server_test.go
@@ -546,6 +546,32 @@ func (srv *IncrementalPollTestServer) PollFlightInfo(ctx context.Context, desc *
                return nil, status.Errorf(codes.NotFound, "Query ID not found")
        }
 
+       if query.query == "infinite" {
+               query.nextIndex++
+
+               descriptor, err := proto.Marshal(&wrapperspb.StringValue{Value: 
queryId})
+               if err != nil {
+                       return nil, err
+               }
+               return &flight.PollInfo{
+                       Info: &flight.FlightInfo{
+                               Schema: nil,
+                               Endpoint: []*flight.FlightEndpoint{{
+                                       Ticket: &flight.Ticket{
+                                               Ticket: []byte{},
+                                       },
+                               }},
+                               AppMetadata: []byte("app metadata"),
+                       },
+                       FlightDescriptor: &flight.FlightDescriptor{
+                               Type: flight.DescriptorCMD,
+                               Cmd:  descriptor,
+                       },
+                       // always makes a bit of progress, never gets anywhere
+                       Progress: proto.Float64(float64(query.nextIndex) / 
100.0),
+               }, nil
+       }
+
        testCase, ok := srv.testCases[query.query]
        if !ok {
                if query.query == "unavailable" {
@@ -581,6 +607,32 @@ func (srv *IncrementalPollTestServer) PollFlightInfoStatement(ctx context.Contex
                }
 
                return srv.MakePollInfo(&unavailableCase, srv.queries[queryId], 
queryId)
+       } else if query.GetQuery() == "infinite" {
+               srv.queries[queryId] = &IncrementalQuery{
+                       query:     query.GetQuery(),
+                       nextIndex: 0,
+               }
+
+               descriptor, err := proto.Marshal(&wrapperspb.StringValue{Value: 
queryId})
+               if err != nil {
+                       return nil, err
+               }
+               return &flight.PollInfo{
+                       Info: &flight.FlightInfo{
+                               Schema: nil,
+                               Endpoint: []*flight.FlightEndpoint{{
+                                       Ticket: &flight.Ticket{
+                                               Ticket: []byte{},
+                                       },
+                               }},
+                               AppMetadata: []byte("app metadata"),
+                       },
+                       FlightDescriptor: &flight.FlightDescriptor{
+                               Type: flight.DescriptorCMD,
+                               Cmd:  descriptor,
+                       },
+                       Progress: proto.Float64(0),
+               }, nil
        }
 
        testCase, ok := srv.testCases[query.GetQuery()]
@@ -790,6 +842,48 @@ func (ts *IncrementalPollTests) TestOptionValue() {
        ts.Equal(adbc.StatusInvalidArgument, adbcErr.Code)
 }
 
+func (ts *IncrementalPollTests) TestAppMetadata() {
+       ctx, cancel := context.WithCancel(context.Background())
+       stmt, err := ts.cnxn.NewStatement()
+       ts.NoError(err)
+       defer stmt.Close()
+
+       ts.NoError(stmt.SetOption(adbc.OptionKeyIncremental, 
adbc.OptionValueEnabled))
+
+       ts.NoError(stmt.SetSqlQuery("infinite"))
+       _, partitions, _, err := stmt.ExecutePartitions(ctx)
+       ts.NoError(err)
+       ts.Equalf(uint64(1), partitions.NumPartitions, "%#v", partitions)
+
+       progress := 0.0
+       go func() {
+               var err error
+               var info []byte
+               for {
+                       // While the below is stuck, we should be able to get 
the app metadata and progress
+                       progress, err = 
stmt.(adbc.GetSetOptions).GetOptionDouble(adbc.OptionKeyProgress)
+                       ts.NoError(err)
+
+                       info, err = 
stmt.(adbc.GetSetOptions).GetOptionBytes(driver.OptionLastFlightInfo)
+                       ts.NoError(err)
+                       var flightInfo flight.FlightInfo
+                       ts.NoError(proto.Unmarshal(info, &flightInfo))
+                       ts.Equal([]byte("app metadata"), flightInfo.AppMetadata)
+
+                       if progress > 0.03 {
+                               break
+                       }
+               }
+               cancel()
+       }()
+
+       // will get stuck forever, but will "make progress"
+       _, _, _, err = stmt.ExecutePartitions(ctx)
+       var adbcErr adbc.Error
+       ts.ErrorAs(err, &adbcErr)
+       ts.Equal(adbc.StatusCancelled, adbcErr.Code)
+}
+
 func (ts *IncrementalPollTests) TestUnavailable() {
        // An error from the server should not tear down all the state.  We
        // should be able to retry the request.
diff --git a/go/adbc/driver/flightsql/flightsql_driver.go b/go/adbc/driver/flightsql/flightsql_driver.go
index 4914ad1c..df1ae688 100644
--- a/go/adbc/driver/flightsql/flightsql_driver.go
+++ b/go/adbc/driver/flightsql/flightsql_driver.go
@@ -60,6 +60,7 @@ const (
        OptionTimeoutUpdate       = "adbc.flight.sql.rpc.timeout_seconds.update"
        OptionRPCCallHeaderPrefix = "adbc.flight.sql.rpc.call_header."
        OptionCookieMiddleware    = "adbc.flight.sql.rpc.with_cookie_middleware"
+       OptionLastFlightInfo      = 
"adbc.flight.sql.statement.exec.last_flight_info"
        infoDriverName            = "ADBC Flight SQL Driver - Go"
 )
 
diff --git a/go/adbc/driver/flightsql/flightsql_statement.go b/go/adbc/driver/flightsql/flightsql_statement.go
index a1e33fd3..d78b653c 100644
--- a/go/adbc/driver/flightsql/flightsql_statement.go
+++ b/go/adbc/driver/flightsql/flightsql_statement.go
@@ -166,6 +166,8 @@ type statement struct {
        timeouts         timeoutOption
        incrementalState *incrementalState
        progress         float64
+       // may seem redundant, but incrementalState isn't locked
+       lastInfo atomic.Pointer[flight.FlightInfo]
 }
 
 func (s *statement) closePreparedStatement() error {
@@ -184,6 +186,7 @@ func (s *statement) clearIncrementalQuery() error {
                        }
                }
                s.incrementalState = &incrementalState{}
+               s.lastInfo.Store(nil)
        }
        return nil
 }
@@ -249,6 +252,21 @@ func (s *statement) GetOption(key string) (string, error) {
        }
 }
 func (s *statement) GetOptionBytes(key string) ([]byte, error) {
+       switch key {
+       case OptionLastFlightInfo:
+               info := s.lastInfo.Load()
+               if info == nil {
+                       return []byte{}, nil
+               }
+               serialized, err := proto.Marshal(info)
+               if err != nil {
+                       return nil, adbc.Error{
+                               Msg:  fmt.Sprintf("[Flight SQL] Could not 
serialize result for '%s': %s", key, err.Error()),
+                               Code: adbc.StatusInternal,
+                       }
+               }
+               return serialized, nil
+       }
        return nil, adbc.Error{
                Msg:  fmt.Sprintf("[Flight SQL] Unknown statement option '%s'", 
key),
                Code: adbc.StatusNotFound,
@@ -594,6 +612,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
                        // Reset the statement for reuse
                        s.incrementalState = &incrementalState{}
                        atomicStoreFloat64(&s.progress, 0.0)
+                       s.lastInfo.Store(nil)
                        return schema, adbc.Partitions{}, totalRecords, nil
                }
 
@@ -628,6 +647,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
                        s.incrementalState.previousInfo = poll.GetInfo()
                        s.incrementalState.retryDescriptor = 
poll.GetFlightDescriptor()
                        atomicStoreFloat64(&s.progress, poll.GetProgress())
+                       s.lastInfo.Store(poll.GetInfo())
 
                        if s.incrementalState.retryDescriptor == nil {
                                // Query is finished
@@ -651,6 +671,7 @@ func (s *statement) ExecutePartitions(ctx context.Context) (*arrow.Schema, adbc.
                if s.incrementalState.complete && len(info.Endpoint) == 0 {
                        s.incrementalState = &incrementalState{}
                        atomicStoreFloat64(&s.progress, 0.0)
+                       s.lastInfo.Store(nil)
                }
        } else if s.prepared != nil {
                info, err = s.prepared.Execute(ctx, grpc.Header(&header), 
grpc.Trailer(&trailer), s.timeouts)
diff --git a/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py b/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py
index 1b9adf31..1af0c199 100644
--- a/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py
+++ b/python/adbc_driver_flightsql/adbc_driver_flightsql/__init__.py
@@ -104,6 +104,15 @@ class ConnectionOptions(enum.Enum):
 class StatementOptions(enum.Enum):
     """Statement options specific to the Flight SQL driver."""
 
+    #: The latest FlightInfo value.
+    #:
+    #: Thread-safe.  Mostly useful when using incremental execution, where an
+    #: advanced client may want to inspect the latest FlightInfo from the
+    #: service, but without waiting for execute_partitions to return.  (The
+    #: service may send an updated FlightInfo with progress/app_metadata
+    #: values, but execute_partitions will only return if there are new
+    #: endpoints.)
+    LAST_FLIGHT_INFO = "adbc.flight.sql.statement.exec.last_flight_info"
     #: The number of batches to queue per partition. Defaults to 5.
     #:
     #: This controls how much we read ahead on result sets.
diff --git a/python/adbc_driver_flightsql/tests/test_incremental.py b/python/adbc_driver_flightsql/tests/test_incremental.py
index 8f47fd39..285c18a8 100644
--- a/python/adbc_driver_flightsql/tests/test_incremental.py
+++ b/python/adbc_driver_flightsql/tests/test_incremental.py
@@ -16,13 +16,16 @@
 # under the License.
 
 import re
+import threading
 
 import google.protobuf.any_pb2 as any_pb2
 import google.protobuf.wrappers_pb2 as wrappers_pb2
 import pyarrow
+import pyarrow.flight
 import pytest
 
 import adbc_driver_manager
+from adbc_driver_flightsql import StatementOptions as FlightSqlStatementOptions
 from adbc_driver_manager import StatementOptions
 
 SCHEMA = pyarrow.schema([("ints", "int32")])
@@ -106,6 +109,54 @@ def test_incremental_error_poll(test_dbapi) -> None:
         assert partitions == []
 
 
+def test_incremental_cancel(test_dbapi) -> None:
+    with test_dbapi.cursor() as cur:
+        assert (
+            cur.adbc_statement.get_option_bytes(
+                FlightSqlStatementOptions.LAST_FLIGHT_INFO.value
+            )
+            == b""
+        )
+
+        cur.adbc_statement.set_options(
+            **{
+                StatementOptions.INCREMENTAL.value: "true",
+            }
+        )
+        partitions, schema = cur.adbc_execute_partitions("forever")
+        assert len(partitions) == 1
+
+        passed = False
+
+        def _bg():
+            nonlocal passed
+            while True:
+                progress = cur.adbc_statement.get_option_float(
+                    StatementOptions.PROGRESS.value
+                )
+                # XXX: upstream PyArrow never bothered exposing app_metadata
+                raw_info = cur.adbc_statement.get_option_bytes(
+                    FlightSqlStatementOptions.LAST_FLIGHT_INFO.value
+                )
+
+                # check that it's a valid info
+                pyarrow.flight.FlightInfo.deserialize(raw_info)
+                passed = b"app metadata" in raw_info
+
+                if progress > 0.07:
+                    break
+            cur.adbc_cancel()
+
+        t = threading.Thread(target=_bg, daemon=True)
+        t.start()
+
+        with pytest.raises(test_dbapi.OperationalError, match="(?i)cancelled"):
+            cur.adbc_execute_partitions("forever")
+
+        t.join()
+        assert passed
+
+
 def test_incremental_immediately(test_dbapi) -> None:
     with test_dbapi.cursor() as cur:
         cur.adbc_statement.set_options(

Reply via email to