This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 143a7da103 GH-39574: [Go] Enable PollFlightInfo in Flight RPC (#39575)
143a7da103 is described below

commit 143a7da1038c3b9b9ad9587f668ca7abcc3520f8
Author: David Li <[email protected]>
AuthorDate: Fri Jan 19 11:25:30 2024 -0500

    GH-39574: [Go] Enable PollFlightInfo in Flight RPC (#39575)
    
    
    
    ### Rationale for this change
    
    It's impossible to use PollFlightInfo through the current Go Flight SQL
    bindings. This is required for apache/arrow-adbc#1457.
    
    ### What changes are included in this PR?
    
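    Add new methods that expose PollFlightInfo: on the client, ExecutePoll and
    ExecuteSubstraitPoll (for both Client and Txn) and
    PreparedStatement.ExecutePoll; on the server, PollFlightInfo,
    PollFlightInfoStatement, PollFlightInfoSubstraitPlan and
    PollFlightInfoPreparedStatement.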
    
    ### Are these changes tested?
    
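    Yes. server_test.go adds TestExecutePoll (initial call plus a retry with
    the returned descriptor) and TestPoll (default Unimplemented responses).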
    
    ### Are there any user-facing changes?
    
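    Yes: adds new public methods to the Flight SQL client types and to the
    Server interface; BaseServer provides default implementations that return
    Unimplemented.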
    * Closes: #39574
    
    Authored-by: David Li <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 go/arrow/flight/flightsql/client.go      | 92 ++++++++++++++++++++++++++++++++
 go/arrow/flight/flightsql/server.go      | 54 +++++++++++++++++++
 go/arrow/flight/flightsql/server_test.go | 60 +++++++++++++++++++++
 3 files changed, 206 insertions(+)

diff --git a/go/arrow/flight/flightsql/client.go 
b/go/arrow/flight/flightsql/client.go
index c0c7e2cf20..89784b483b 100644
--- a/go/arrow/flight/flightsql/client.go
+++ b/go/arrow/flight/flightsql/client.go
@@ -82,6 +82,17 @@ func flightInfoForCommand(ctx context.Context, cl *Client, 
cmd proto.Message, op
        return cl.getFlightInfo(ctx, desc, opts...)
 }
 
+func pollInfoForCommand(ctx context.Context, cl *Client, cmd proto.Message, 
retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) 
(*flight.PollInfo, error) {
+       if retryDescriptor != nil {
+               return cl.Client.PollFlightInfo(ctx, retryDescriptor, opts...)
+       }
+       desc, err := descForCommand(cmd)
+       if err != nil {
+               return nil, err
+       }
+       return cl.Client.PollFlightInfo(ctx, desc, opts...)
+}
+
 func schemaForCommand(ctx context.Context, cl *Client, cmd proto.Message, opts 
...grpc.CallOption) (*flight.SchemaResult, error) {
        desc, err := descForCommand(cmd)
        if err != nil {
@@ -123,6 +134,14 @@ func (c *Client) Execute(ctx context.Context, query 
string, opts ...grpc.CallOpt
        return flightInfoForCommand(ctx, c, &cmd, opts...)
 }
 
+// ExecutePoll idempotently starts execution of a query/checks for completion.
+// To check for completion, pass the FlightDescriptor from the previous call
+// to ExecutePoll as the retryDescriptor.
+func (c *Client) ExecutePoll(ctx context.Context, query string, 
retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) 
(*flight.PollInfo, error) {
+       cmd := pb.CommandStatementQuery{Query: query}
+       return pollInfoForCommand(ctx, c, &cmd, retryDescriptor, opts...)
+}
+
 // GetExecuteSchema gets the schema of the result set of a query without
 // executing the query itself.
 func (c *Client) GetExecuteSchema(ctx context.Context, query string, opts 
...grpc.CallOption) (*flight.SchemaResult, error) {
@@ -136,6 +155,12 @@ func (c *Client) ExecuteSubstrait(ctx context.Context, 
plan SubstraitPlan, opts
        return flightInfoForCommand(ctx, c, &cmd, opts...)
 }
 
+func (c *Client) ExecuteSubstraitPoll(ctx context.Context, plan SubstraitPlan, 
retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) 
(*flight.PollInfo, error) {
+       cmd := pb.CommandStatementSubstraitPlan{
+               Plan: &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version}}
+       return pollInfoForCommand(ctx, c, &cmd, retryDescriptor, opts...)
+}
+
 func (c *Client) GetExecuteSubstraitSchema(ctx context.Context, plan 
SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
        cmd := pb.CommandStatementSubstraitPlan{
                Plan: &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version}}
@@ -606,6 +631,15 @@ func (tx *Txn) Execute(ctx context.Context, query string, 
opts ...grpc.CallOptio
        return flightInfoForCommand(ctx, tx.c, cmd, opts...)
 }
 
+func (tx *Txn) ExecutePoll(ctx context.Context, query string, retryDescriptor 
*flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
+       if !tx.txn.IsValid() {
+               return nil, ErrInvalidTxn
+       }
+       // The server should encode the transaction into the retry descriptor
+       cmd := &pb.CommandStatementQuery{Query: query, TransactionId: tx.txn}
+       return pollInfoForCommand(ctx, tx.c, cmd, retryDescriptor, opts...)
+}
+
 func (tx *Txn) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts 
...grpc.CallOption) (*flight.FlightInfo, error) {
        if !tx.txn.IsValid() {
                return nil, ErrInvalidTxn
@@ -616,6 +650,18 @@ func (tx *Txn) ExecuteSubstrait(ctx context.Context, plan 
SubstraitPlan, opts ..
        return flightInfoForCommand(ctx, tx.c, cmd, opts...)
 }
 
+func (tx *Txn) ExecuteSubstraitPoll(ctx context.Context, plan SubstraitPlan, 
retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) 
(*flight.PollInfo, error) {
+       if !tx.txn.IsValid() {
+               return nil, ErrInvalidTxn
+       }
+       // The server should encode the transaction into the retry descriptor
+       cmd := &pb.CommandStatementSubstraitPlan{
+               Plan:          &pb.SubstraitPlan{Plan: plan.Plan, Version: 
plan.Version},
+               TransactionId: tx.txn,
+       }
+       return pollInfoForCommand(ctx, tx.c, cmd, retryDescriptor, opts...)
+}
+
 func (tx *Txn) GetExecuteSchema(ctx context.Context, query string, opts 
...grpc.CallOption) (*flight.SchemaResult, error) {
        if !tx.txn.IsValid() {
                return nil, ErrInvalidTxn
@@ -981,6 +1027,52 @@ func (p *PreparedStatement) Execute(ctx context.Context, 
opts ...grpc.CallOption
        return p.client.getFlightInfo(ctx, desc, opts...)
 }
 
+// ExecutePoll executes the prepared statement on the server and returns a 
PollInfo
+// indicating the progress of execution.
+//
+// Will error if already closed.
+func (p *PreparedStatement) ExecutePoll(ctx context.Context, retryDescriptor 
*flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
+       if p.closed {
+               return nil, errors.New("arrow/flightsql: prepared statement 
already closed")
+       }
+
+       cmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: 
p.handle}
+
+       desc := retryDescriptor
+       var err error
+
+       if desc == nil {
+               desc, err = descForCommand(cmd)
+               if err != nil {
+                       return nil, err
+               }
+       }
+
+       if retryDescriptor == nil {
+               if p.hasBindParameters() {
+                       pstream, err := p.client.Client.DoPut(ctx, opts...)
+                       if err != nil {
+                               return nil, err
+                       }
+
+                       wr, err := p.writeBindParameters(pstream, desc)
+                       if err != nil {
+                               return nil, err
+                       }
+                       if err = wr.Close(); err != nil {
+                               return nil, err
+                       }
+                       pstream.CloseSend()
+
+                       // wait for the server to ack the result
+                       if _, err = pstream.Recv(); err != nil && err != io.EOF 
{
+                               return nil, err
+                       }
+               }
+       }
+       return p.client.Client.PollFlightInfo(ctx, desc, opts...)
+}
+
 // ExecuteUpdate executes the prepared statement update query on the server
 // and returns the number of rows affected. If SetParameters was called,
 // the parameter bindings will be sent with the request to execute.
diff --git a/go/arrow/flight/flightsql/server.go 
b/go/arrow/flight/flightsql/server.go
index 5b1764707c..2ec02e2829 100644
--- a/go/arrow/flight/flightsql/server.go
+++ b/go/arrow/flight/flightsql/server.go
@@ -524,6 +524,22 @@ func (BaseServer) RenewFlightEndpoint(context.Context, 
*flight.RenewFlightEndpoi
        return nil, status.Error(codes.Unimplemented, "RenewFlightEndpoint not 
implemented")
 }
 
+func (BaseServer) PollFlightInfo(context.Context, *flight.FlightDescriptor) 
(*flight.PollInfo, error) {
+       return nil, status.Error(codes.Unimplemented, "PollFlightInfo not 
implemented")
+}
+
+func (BaseServer) PollFlightInfoStatement(context.Context, StatementQuery, 
*flight.FlightDescriptor) (*flight.PollInfo, error) {
+       return nil, status.Error(codes.Unimplemented, "PollFlightInfoStatement 
not implemented")
+}
+
+func (BaseServer) PollFlightInfoSubstraitPlan(context.Context, 
StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.PollInfo, error) {
+       return nil, status.Error(codes.Unimplemented, 
"PollFlightInfoSubstraitPlan not implemented")
+}
+
+func (BaseServer) PollFlightInfoPreparedStatement(context.Context, 
PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) {
+       return nil, status.Error(codes.Unimplemented, 
"PollFlightInfoPreparedStatement not implemented")
+}
+
 func (BaseServer) EndTransaction(context.Context, ActionEndTransactionRequest) 
error {
        return status.Error(codes.Unimplemented, "EndTransaction not 
implemented")
 }
@@ -652,6 +668,14 @@ type Server interface {
        CancelFlightInfo(context.Context, *flight.CancelFlightInfoRequest) 
(flight.CancelFlightInfoResult, error)
        // RenewFlightEndpoint attempts to extend the expiration of a 
FlightEndpoint
        RenewFlightEndpoint(context.Context, 
*flight.RenewFlightEndpointRequest) (*flight.FlightEndpoint, error)
+       // PollFlightInfo is a generic handler for PollFlightInfo requests.
+       PollFlightInfo(context.Context, *flight.FlightDescriptor) 
(*flight.PollInfo, error)
+       // PollFlightInfoStatement handles polling for query execution.
+       PollFlightInfoStatement(context.Context, StatementQuery, 
*flight.FlightDescriptor) (*flight.PollInfo, error)
+       // PollFlightInfoSubstraitPlan handles polling for query execution.
+       PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, 
*flight.FlightDescriptor) (*flight.PollInfo, error)
+       // PollFlightInfoPreparedStatement handles polling for query execution.
+       PollFlightInfoPreparedStatement(context.Context, 
PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error)
 
        mustEmbedBaseServer()
 }
@@ -729,6 +753,36 @@ func (f *flightSqlServer) GetFlightInfo(ctx 
context.Context, request *flight.Fli
        return nil, status.Error(codes.InvalidArgument, "requested command is 
invalid")
 }
 
+func (f *flightSqlServer) PollFlightInfo(ctx context.Context, request 
*flight.FlightDescriptor) (*flight.PollInfo, error) {
+       var (
+               anycmd anypb.Any
+               cmd    proto.Message
+               err    error
+       )
+       // If we can't parse things, be friendly and defer to the server
+       // implementation. This is especially important for this method since
+       // the server returns a custom FlightDescriptor for future requests.
+       if err = proto.Unmarshal(request.Cmd, &anycmd); err != nil {
+               return f.srv.PollFlightInfo(ctx, request)
+       }
+
+       if cmd, err = anycmd.UnmarshalNew(); err != nil {
+               return f.srv.PollFlightInfo(ctx, request)
+       }
+
+       switch cmd := cmd.(type) {
+       case *pb.CommandStatementQuery:
+               return f.srv.PollFlightInfoStatement(ctx, cmd, request)
+       case *pb.CommandStatementSubstraitPlan:
+               return f.srv.PollFlightInfoSubstraitPlan(ctx, 
&statementSubstraitPlan{cmd}, request)
+       case *pb.CommandPreparedStatementQuery:
+               return f.srv.PollFlightInfoPreparedStatement(ctx, cmd, request)
+       }
+       // XXX: for now we won't support the other methods
+
+       return f.srv.PollFlightInfo(ctx, request)
+}
+
 func (f *flightSqlServer) GetSchema(ctx context.Context, request 
*flight.FlightDescriptor) (*flight.SchemaResult, error) {
        var (
                anycmd anypb.Any
diff --git a/go/arrow/flight/flightsql/server_test.go 
b/go/arrow/flight/flightsql/server_test.go
index e444da4aaf..956a1714c6 100644
--- a/go/arrow/flight/flightsql/server_test.go
+++ b/go/arrow/flight/flightsql/server_test.go
@@ -56,6 +56,36 @@ func (*testServer) GetFlightInfoStatement(ctx 
context.Context, q flightsql.State
        }, nil
 }
 
+func (*testServer) PollFlightInfo(ctx context.Context, fd 
*flight.FlightDescriptor) (*flight.PollInfo, error) {
+       return &flight.PollInfo{
+               Info: &flight.FlightInfo{
+                       FlightDescriptor: fd,
+                       Endpoint: []*flight.FlightEndpoint{{
+                               Ticket: &flight.Ticket{Ticket: []byte{}},
+                       }, {
+                               Ticket: &flight.Ticket{Ticket: []byte{}},
+                       }},
+               },
+               FlightDescriptor: nil,
+       }, nil
+}
+
+func (*testServer) PollFlightInfoStatement(ctx context.Context, q 
flightsql.StatementQuery, fd *flight.FlightDescriptor) (*flight.PollInfo, 
error) {
+       ticket, err := 
flightsql.CreateStatementQueryTicket([]byte(q.GetQuery()))
+       if err != nil {
+               return nil, err
+       }
+       return &flight.PollInfo{
+               Info: &flight.FlightInfo{
+                       FlightDescriptor: fd,
+                       Endpoint: []*flight.FlightEndpoint{{
+                               Ticket: &flight.Ticket{Ticket: ticket},
+                       }},
+               },
+               FlightDescriptor: &flight.FlightDescriptor{Cmd: []byte{}},
+       }, nil
+}
+
 func (*testServer) DoGetStatement(ctx context.Context, ticket 
flightsql.StatementQueryTicket) (sc *arrow.Schema, cc <-chan 
flight.StreamChunk, err error) {
        handle := string(ticket.GetStatementHandle())
        switch handle {
@@ -189,6 +219,20 @@ func (s *FlightSqlServerSuite) TestExecuteChunkError() {
        }
 }
 
+func (s *FlightSqlServerSuite) TestExecutePoll() {
+       poll, err := s.cl.ExecutePoll(context.TODO(), "1", nil)
+       s.NoError(err)
+       s.NotNil(poll)
+       s.NotNil(poll.GetFlightDescriptor())
+       s.Len(poll.GetInfo().Endpoint, 1)
+
+       poll, err = s.cl.ExecutePoll(context.TODO(), "1", 
poll.GetFlightDescriptor())
+       s.NoError(err)
+       s.NotNil(poll)
+       s.Nil(poll.GetFlightDescriptor())
+       s.Len(poll.GetInfo().Endpoint, 2)
+}
+
 type UnimplementedFlightSqlServerSuite struct {
        suite.Suite
 
@@ -314,6 +358,22 @@ func (s *UnimplementedFlightSqlServerSuite) 
TestGetTypeInfo() {
        s.Nil(info)
 }
 
+func (s *UnimplementedFlightSqlServerSuite) TestPoll() {
+       poll, err := s.cl.ExecutePoll(context.TODO(), "", nil)
+       st, ok := status.FromError(err)
+       s.True(ok)
+       s.Equal(codes.Unimplemented, st.Code())
+       s.Equal("PollFlightInfoStatement not implemented", st.Message())
+       s.Nil(poll)
+
+       poll, err = s.cl.ExecuteSubstraitPoll(context.TODO(), 
flightsql.SubstraitPlan{}, nil)
+       st, ok = status.FromError(err)
+       s.True(ok)
+       s.Equal(codes.Unimplemented, st.Code())
+       s.Equal("PollFlightInfoSubstraitPlan not implemented", st.Message())
+       s.Nil(poll)
+}
+
 func getTicket(cmd proto.Message) *flight.Ticket {
        var anycmd anypb.Any
        anycmd.MarshalFrom(cmd)

Reply via email to