This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 143a7da103 GH-39574: [Go] Enable PollFlightInfo in Flight RPC (#39575)
143a7da103 is described below
commit 143a7da1038c3b9b9ad9587f668ca7abcc3520f8
Author: David Li <[email protected]>
AuthorDate: Fri Jan 19 11:25:30 2024 -0500
GH-39574: [Go] Enable PollFlightInfo in Flight RPC (#39575)
### Rationale for this change
It's impossible to use the current bindings with PollFlightInfo. Required
for apache/arrow-adbc#1457.
### What changes are included in this PR?
Add new methods that expose PollFlightInfo.
### Are these changes tested?
Yes
### Are there any user-facing changes?
Adds new methods.
* Closes: #39574
Authored-by: David Li <[email protected]>
Signed-off-by: David Li <[email protected]>
---
go/arrow/flight/flightsql/client.go | 92 ++++++++++++++++++++++++++++++++
go/arrow/flight/flightsql/server.go | 54 +++++++++++++++++++
go/arrow/flight/flightsql/server_test.go | 60 +++++++++++++++++++++
3 files changed, 206 insertions(+)
diff --git a/go/arrow/flight/flightsql/client.go b/go/arrow/flight/flightsql/client.go
index c0c7e2cf20..89784b483b 100644
--- a/go/arrow/flight/flightsql/client.go
+++ b/go/arrow/flight/flightsql/client.go
@@ -82,6 +82,17 @@ func flightInfoForCommand(ctx context.Context, cl *Client, cmd proto.Message, op
return cl.getFlightInfo(ctx, desc, opts...)
}
+func pollInfoForCommand(ctx context.Context, cl *Client, cmd proto.Message,
retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption)
(*flight.PollInfo, error) {
+ if retryDescriptor != nil {
+ return cl.Client.PollFlightInfo(ctx, retryDescriptor, opts...)
+ }
+ desc, err := descForCommand(cmd)
+ if err != nil {
+ return nil, err
+ }
+ return cl.Client.PollFlightInfo(ctx, desc, opts...)
+}
+
func schemaForCommand(ctx context.Context, cl *Client, cmd proto.Message, opts
...grpc.CallOption) (*flight.SchemaResult, error) {
desc, err := descForCommand(cmd)
if err != nil {
@@ -123,6 +134,14 @@ func (c *Client) Execute(ctx context.Context, query string, opts ...grpc.CallOpt
return flightInfoForCommand(ctx, c, &cmd, opts...)
}
+// ExecutePoll starts execution of query through PollFlightInfo, or checks on
+// an execution that is already in progress.
+//
+// On the first call pass a nil retryDescriptor. To check for completion, pass
+// the FlightDescriptor returned in the previous PollInfo as retryDescriptor;
+// that descriptor is then sent as-is and query is not re-encoded.
+func (c *Client) ExecutePoll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
+	return pollInfoForCommand(ctx, c, &pb.CommandStatementQuery{Query: query}, retryDescriptor, opts...)
+}
+
// GetExecuteSchema gets the schema of the result set of a query without
// executing the query itself.
func (c *Client) GetExecuteSchema(ctx context.Context, query string, opts
...grpc.CallOption) (*flight.SchemaResult, error) {
@@ -136,6 +155,12 @@ func (c *Client) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts
return flightInfoForCommand(ctx, c, &cmd, opts...)
}
+// ExecuteSubstraitPoll is the Substrait-plan counterpart of ExecutePoll: it
+// starts execution of plan through PollFlightInfo or, given a non-nil
+// retryDescriptor from a previous PollInfo, checks on that execution.
+func (c *Client) ExecuteSubstraitPoll(ctx context.Context, plan SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
+	substrait := &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version}
+	cmd := &pb.CommandStatementSubstraitPlan{Plan: substrait}
+	return pollInfoForCommand(ctx, c, cmd, retryDescriptor, opts...)
+}
+
func (c *Client) GetExecuteSubstraitSchema(ctx context.Context, plan
SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
cmd := pb.CommandStatementSubstraitPlan{
Plan: &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version}}
@@ -606,6 +631,15 @@ func (tx *Txn) Execute(ctx context.Context, query string, opts ...grpc.CallOptio
return flightInfoForCommand(ctx, tx.c, cmd, opts...)
}
+// ExecutePoll is like Client.ExecutePoll but runs query inside this
+// transaction. It returns ErrInvalidTxn if tx is not valid.
+func (tx *Txn) ExecutePoll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
+	if tx.txn.IsValid() {
+		// On retries only retryDescriptor is sent, so the server is expected
+		// to encode the transaction into the descriptor it hands back.
+		cmd := &pb.CommandStatementQuery{Query: query, TransactionId: tx.txn}
+		return pollInfoForCommand(ctx, tx.c, cmd, retryDescriptor, opts...)
+	}
+	return nil, ErrInvalidTxn
+}
+
func (tx *Txn) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts
...grpc.CallOption) (*flight.FlightInfo, error) {
if !tx.txn.IsValid() {
return nil, ErrInvalidTxn
@@ -616,6 +650,18 @@ func (tx *Txn) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts ..
return flightInfoForCommand(ctx, tx.c, cmd, opts...)
}
+// ExecuteSubstraitPoll is like Client.ExecuteSubstraitPoll but runs plan
+// inside this transaction. It returns ErrInvalidTxn if tx is not valid.
+func (tx *Txn) ExecuteSubstraitPoll(ctx context.Context, plan SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
+	if !tx.txn.IsValid() {
+		return nil, ErrInvalidTxn
+	}
+	substrait := &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version}
+	// On retries only retryDescriptor is sent, so the server is expected to
+	// encode the transaction into the descriptor it hands back.
+	return pollInfoForCommand(ctx, tx.c,
+		&pb.CommandStatementSubstraitPlan{Plan: substrait, TransactionId: tx.txn},
+		retryDescriptor, opts...)
+}
+
func (tx *Txn) GetExecuteSchema(ctx context.Context, query string, opts
...grpc.CallOption) (*flight.SchemaResult, error) {
if !tx.txn.IsValid() {
return nil, ErrInvalidTxn
@@ -981,6 +1027,52 @@ func (p *PreparedStatement) Execute(ctx context.Context, opts ...grpc.CallOption
return p.client.getFlightInfo(ctx, desc, opts...)
}
+// ExecutePoll executes the prepared statement on the server and returns a
PollInfo
+// indicating the progress of execution.
+//
+// Will error if already closed.
+func (p *PreparedStatement) ExecutePoll(ctx context.Context, retryDescriptor
*flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
+ if p.closed {
+ return nil, errors.New("arrow/flightsql: prepared statement
already closed")
+ }
+
+ cmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle:
p.handle}
+
+ desc := retryDescriptor
+ var err error
+
+ if desc == nil {
+ desc, err = descForCommand(cmd)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ if retryDescriptor == nil {
+ if p.hasBindParameters() {
+ pstream, err := p.client.Client.DoPut(ctx, opts...)
+ if err != nil {
+ return nil, err
+ }
+
+ wr, err := p.writeBindParameters(pstream, desc)
+ if err != nil {
+ return nil, err
+ }
+ if err = wr.Close(); err != nil {
+ return nil, err
+ }
+ pstream.CloseSend()
+
+ // wait for the server to ack the result
+ if _, err = pstream.Recv(); err != nil && err != io.EOF
{
+ return nil, err
+ }
+ }
+ }
+ return p.client.Client.PollFlightInfo(ctx, desc, opts...)
+}
+
// ExecuteUpdate executes the prepared statement update query on the server
// and returns the number of rows affected. If SetParameters was called,
// the parameter bindings will be sent with the request to execute.
diff --git a/go/arrow/flight/flightsql/server.go b/go/arrow/flight/flightsql/server.go
index 5b1764707c..2ec02e2829 100644
--- a/go/arrow/flight/flightsql/server.go
+++ b/go/arrow/flight/flightsql/server.go
@@ -524,6 +524,22 @@ func (BaseServer) RenewFlightEndpoint(context.Context, *flight.RenewFlightEndpoi
return nil, status.Error(codes.Unimplemented, "RenewFlightEndpoint not
implemented")
}
+func (BaseServer) PollFlightInfo(context.Context, *flight.FlightDescriptor)
(*flight.PollInfo, error) {
+ return nil, status.Error(codes.Unimplemented, "PollFlightInfo not
implemented")
+}
+
+func (BaseServer) PollFlightInfoStatement(context.Context, StatementQuery,
*flight.FlightDescriptor) (*flight.PollInfo, error) {
+ return nil, status.Error(codes.Unimplemented, "PollFlightInfoStatement
not implemented")
+}
+
+// PollFlightInfoSubstraitPlan is the default Substrait-plan poll handler; it
+// always fails with codes.Unimplemented.
+func (BaseServer) PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.PollInfo, error) {
+	err := status.Error(codes.Unimplemented, "PollFlightInfoSubstraitPlan not implemented")
+	return nil, err
+}
+
+// PollFlightInfoPreparedStatement is the default prepared-statement poll
+// handler; it always fails with codes.Unimplemented.
+func (BaseServer) PollFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) {
+	err := status.Error(codes.Unimplemented, "PollFlightInfoPreparedStatement not implemented")
+	return nil, err
+}
+
func (BaseServer) EndTransaction(context.Context, ActionEndTransactionRequest)
error {
return status.Error(codes.Unimplemented, "EndTransaction not
implemented")
}
@@ -652,6 +668,14 @@ type Server interface {
CancelFlightInfo(context.Context, *flight.CancelFlightInfoRequest)
(flight.CancelFlightInfoResult, error)
// RenewFlightEndpoint attempts to extend the expiration of a
FlightEndpoint
RenewFlightEndpoint(context.Context,
*flight.RenewFlightEndpointRequest) (*flight.FlightEndpoint, error)
+ // PollFlightInfo is a generic handler for PollFlightInfo requests.
+ PollFlightInfo(context.Context, *flight.FlightDescriptor)
(*flight.PollInfo, error)
+ // PollFlightInfoStatement handles polling for query execution.
+ PollFlightInfoStatement(context.Context, StatementQuery,
*flight.FlightDescriptor) (*flight.PollInfo, error)
+ // PollFlightInfoSubstraitPlan handles polling for query execution.
+ PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan,
*flight.FlightDescriptor) (*flight.PollInfo, error)
+ // PollFlightInfoPreparedStatement handles polling for query execution.
+ PollFlightInfoPreparedStatement(context.Context,
PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error)
mustEmbedBaseServer()
}
@@ -729,6 +753,36 @@ func (f *flightSqlServer) GetFlightInfo(ctx context.Context, request *flight.Fli
return nil, status.Error(codes.InvalidArgument, "requested command is
invalid")
}
+// PollFlightInfo decodes the descriptor's Cmd as a Flight SQL command and
+// dispatches it to the matching PollFlightInfo* method of the wrapped server.
+// Any descriptor that cannot be decoded, or that carries a command other than
+// a statement query, Substrait plan, or prepared-statement query, goes to the
+// generic f.srv.PollFlightInfo handler unchanged. That leniency matters here
+// because a server may return its own custom FlightDescriptor for follow-up
+// polls.
+func (f *flightSqlServer) PollFlightInfo(ctx context.Context, request *flight.FlightDescriptor) (*flight.PollInfo, error) {
+	var anycmd anypb.Any
+	if err := proto.Unmarshal(request.Cmd, &anycmd); err != nil {
+		return f.srv.PollFlightInfo(ctx, request)
+	}
+	msg, err := anycmd.UnmarshalNew()
+	if err != nil {
+		return f.srv.PollFlightInfo(ctx, request)
+	}
+
+	switch cmd := msg.(type) {
+	case *pb.CommandStatementQuery:
+		return f.srv.PollFlightInfoStatement(ctx, cmd, request)
+	case *pb.CommandStatementSubstraitPlan:
+		return f.srv.PollFlightInfoSubstraitPlan(ctx, &statementSubstraitPlan{cmd}, request)
+	case *pb.CommandPreparedStatementQuery:
+		return f.srv.PollFlightInfoPreparedStatement(ctx, cmd, request)
+	default:
+		// XXX: the other Flight SQL commands do not support polling yet.
+		return f.srv.PollFlightInfo(ctx, request)
+	}
+}
+
func (f *flightSqlServer) GetSchema(ctx context.Context, request
*flight.FlightDescriptor) (*flight.SchemaResult, error) {
var (
anycmd anypb.Any
diff --git a/go/arrow/flight/flightsql/server_test.go b/go/arrow/flight/flightsql/server_test.go
index e444da4aaf..956a1714c6 100644
--- a/go/arrow/flight/flightsql/server_test.go
+++ b/go/arrow/flight/flightsql/server_test.go
@@ -56,6 +56,36 @@ func (*testServer) GetFlightInfoStatement(ctx context.Context, q flightsql.State
}, nil
}
+// PollFlightInfo handles polls the Flight SQL layer could not route to a
+// typed handler. In TestExecutePoll, that is the follow-up poll carrying the
+// retry descriptor issued by PollFlightInfoStatement. It reports a finished
+// query: two endpoints with empty tickets and no retry descriptor.
+func (*testServer) PollFlightInfo(ctx context.Context, fd *flight.FlightDescriptor) (*flight.PollInfo, error) {
+	emptyTicketEndpoint := func() *flight.FlightEndpoint {
+		return &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: []byte{}}}
+	}
+	info := &flight.FlightInfo{
+		FlightDescriptor: fd,
+		Endpoint:         []*flight.FlightEndpoint{emptyTicketEndpoint(), emptyTicketEndpoint()},
+	}
+	// Per the Flight RPC PollInfo contract, a nil descriptor means "done".
+	return &flight.PollInfo{Info: info, FlightDescriptor: nil}, nil
+}
+
+// PollFlightInfoStatement reports a partially complete query: one endpoint
+// whose ticket encodes the query text, plus a non-nil retry descriptor (with
+// an empty Cmd) so that the client polls again.
+func (*testServer) PollFlightInfoStatement(ctx context.Context, q flightsql.StatementQuery, fd *flight.FlightDescriptor) (*flight.PollInfo, error) {
+	ticket, err := flightsql.CreateStatementQueryTicket([]byte(q.GetQuery()))
+	if err != nil {
+		return nil, err
+	}
+	endpoint := &flight.FlightEndpoint{Ticket: &flight.Ticket{Ticket: ticket}}
+	// An empty Cmd does not decode to a Flight SQL command, so the follow-up
+	// poll falls through to the generic PollFlightInfo handler.
+	retry := &flight.FlightDescriptor{Cmd: []byte{}}
+	return &flight.PollInfo{
+		Info:             &flight.FlightInfo{FlightDescriptor: fd, Endpoint: []*flight.FlightEndpoint{endpoint}},
+		FlightDescriptor: retry,
+	}, nil
+}
+
func (*testServer) DoGetStatement(ctx context.Context, ticket
flightsql.StatementQueryTicket) (sc *arrow.Schema, cc <-chan
flight.StreamChunk, err error) {
handle := string(ticket.GetStatementHandle())
switch handle {
@@ -189,6 +219,20 @@ func (s *FlightSqlServerSuite) TestExecuteChunkError() {
}
}
+// TestExecutePoll drives a two-step poll. The first call returns partial
+// results and a retry descriptor; passing that descriptor back yields the
+// final result with no further descriptor.
+func (s *FlightSqlServerSuite) TestExecutePoll() {
+	// Preconditions use Require: if a call fails, poll is nil, and reading
+	// poll.GetInfo().Endpoint would panic instead of failing the test cleanly.
+	poll, err := s.cl.ExecutePoll(context.TODO(), "1", nil)
+	s.Require().NoError(err)
+	s.Require().NotNil(poll)
+	s.Require().NotNil(poll.GetFlightDescriptor())
+	s.Len(poll.GetInfo().GetEndpoint(), 1)
+
+	poll, err = s.cl.ExecutePoll(context.TODO(), "1", poll.GetFlightDescriptor())
+	s.Require().NoError(err)
+	s.Require().NotNil(poll)
+	s.Nil(poll.GetFlightDescriptor())
+	s.Len(poll.GetInfo().GetEndpoint(), 2)
+}
+
type UnimplementedFlightSqlServerSuite struct {
suite.Suite
@@ -314,6 +358,22 @@ func (s *UnimplementedFlightSqlServerSuite) TestGetTypeInfo() {
s.Nil(info)
}
+// TestPoll checks that BaseServer's default poll handlers fail with
+// codes.Unimplemented and a method-specific message.
+func (s *UnimplementedFlightSqlServerSuite) TestPoll() {
+	checkUnimplemented := func(poll *flight.PollInfo, err error, want string) {
+		st, ok := status.FromError(err)
+		s.True(ok)
+		s.Equal(codes.Unimplemented, st.Code())
+		s.Equal(want, st.Message())
+		s.Nil(poll)
+	}
+
+	poll, err := s.cl.ExecutePoll(context.TODO(), "", nil)
+	checkUnimplemented(poll, err, "PollFlightInfoStatement not implemented")
+
+	poll, err = s.cl.ExecuteSubstraitPoll(context.TODO(), flightsql.SubstraitPlan{}, nil)
+	checkUnimplemented(poll, err, "PollFlightInfoSubstraitPlan not implemented")
+}
+
func getTicket(cmd proto.Message) *flight.Ticket {
var anycmd anypb.Any
anycmd.MarshalFrom(cmd)