This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new d60ff53394 GH-41427: [Go] Fix stateless prepared statements (#41428)
d60ff53394 is described below

commit d60ff53394788aef9a6070dfdf46a2bcade128ad
Author: David Li <[email protected]>
AuthorDate: Tue Apr 30 08:46:26 2024 +0900

    GH-41427: [Go] Fix stateless prepared statements (#41428)
    
    
    
    ### Rationale for this change
    
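    Stateless prepared statements didn't actually work: after binding
    parameters, the client kept building requests from the original handle
    instead of the updated handle returned by the server.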
    
    ### What changes are included in this PR?
    
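    Rebuild the flight descriptor from the updated handle after binding
    parameters, and share that logic across Execute, ExecutePut, and
    ExecutePoll via a single bindParameters helper.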
    
    ### Are these changes tested?
    
    Yes
    
    ### Are there any user-facing changes?
    
    No
    
    * GitHub Issue: #41427
    
    Authored-by: David Li <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 go/arrow/flight/flightsql/client.go      | 93 ++++++++++++++------------------
 go/arrow/flight/flightsql/client_test.go | 10 ++--
 2 files changed, 45 insertions(+), 58 deletions(-)

diff --git a/go/arrow/flight/flightsql/client.go 
b/go/arrow/flight/flightsql/client.go
index e594191c35..c6794820dc 100644
--- a/go/arrow/flight/flightsql/client.go
+++ b/go/arrow/flight/flightsql/client.go
@@ -1119,24 +1119,10 @@ func (p *PreparedStatement) Execute(ctx 
context.Context, opts ...grpc.CallOption
                return nil, err
        }
 
-       if p.hasBindParameters() {
-               pstream, err := p.client.Client.DoPut(ctx, opts...)
-               if err != nil {
-                       return nil, err
-               }
-               wr, err := p.writeBindParameters(pstream, desc)
-               if err != nil {
-                       return nil, err
-               }
-               if err = wr.Close(); err != nil {
-                       return nil, err
-               }
-               pstream.CloseSend()
-               if err = p.captureDoPutPreparedStatementHandle(pstream); err != 
nil {
-                       return nil, err
-               }
+       desc, err = p.bindParameters(ctx, desc, opts...)
+       if err != nil {
+               return nil, err
        }
-
        return p.client.getFlightInfo(ctx, desc, opts...)
 }
 
@@ -1156,23 +1142,9 @@ func (p *PreparedStatement) ExecutePut(ctx 
context.Context, opts ...grpc.CallOpt
                return err
        }
 
-       if p.hasBindParameters() {
-               pstream, err := p.client.Client.DoPut(ctx, opts...)
-               if err != nil {
-                       return err
-               }
-
-               wr, err := p.writeBindParameters(pstream, desc)
-               if err != nil {
-                       return err
-               }
-               if err = wr.Close(); err != nil {
-                       return err
-               }
-               pstream.CloseSend()
-               if err = p.captureDoPutPreparedStatementHandle(pstream); err != 
nil {
-                       return err
-               }
+       _, err = p.bindParameters(ctx, desc, opts...)
+       if err != nil {
+               return err
        }
 
        return nil
@@ -1200,23 +1172,9 @@ func (p *PreparedStatement) ExecutePoll(ctx 
context.Context, retryDescriptor *fl
        }
 
        if retryDescriptor == nil {
-               if p.hasBindParameters() {
-                       pstream, err := p.client.Client.DoPut(ctx, opts...)
-                       if err != nil {
-                               return nil, err
-                       }
-
-                       wr, err := p.writeBindParameters(pstream, desc)
-                       if err != nil {
-                               return nil, err
-                       }
-                       if err = wr.Close(); err != nil {
-                               return nil, err
-                       }
-                       pstream.CloseSend()
-                       if err = 
p.captureDoPutPreparedStatementHandle(pstream); err != nil {
-                               return nil, err
-                       }
+               desc, err = p.bindParameters(ctx, desc, opts...)
+               if err != nil {
+                       return nil, err
                }
        }
        return p.client.Client.PollFlightInfo(ctx, desc, opts...)
@@ -1248,7 +1206,7 @@ func (p *PreparedStatement) ExecuteUpdate(ctx 
context.Context, opts ...grpc.Call
                return
        }
        if p.hasBindParameters() {
-               wr, err = p.writeBindParameters(pstream, desc)
+               wr, err = p.writeBindParametersToStream(pstream, desc)
                if err != nil {
                        return
                }
@@ -1283,7 +1241,36 @@ func (p *PreparedStatement) hasBindParameters() bool {
        return (p.paramBinding != nil && p.paramBinding.NumRows() > 0) || 
(p.streamBinding != nil)
 }
 
-func (p *PreparedStatement) writeBindParameters(pstream 
pb.FlightService_DoPutClient, desc *pb.FlightDescriptor) (*flight.Writer, 
error) {
+func (p *PreparedStatement) bindParameters(ctx context.Context, desc 
*pb.FlightDescriptor, opts ...grpc.CallOption) (*flight.FlightDescriptor, 
error) {
+       if p.hasBindParameters() {
+               pstream, err := p.client.Client.DoPut(ctx, opts...)
+               if err != nil {
+                       return nil, err
+               }
+               wr, err := p.writeBindParametersToStream(pstream, desc)
+               if err != nil {
+                       return nil, err
+               }
+               if err = wr.Close(); err != nil {
+                       return nil, err
+               }
+               pstream.CloseSend()
+               if err = p.captureDoPutPreparedStatementHandle(pstream); err != 
nil {
+                       return nil, err
+               }
+
+               cmd := 
pb.CommandPreparedStatementQuery{PreparedStatementHandle: p.handle}
+               desc, err = descForCommand(&cmd)
+               if err != nil {
+                       return nil, err
+               }
+               return desc, nil
+       }
+       return desc, nil
+}
+
+// XXX: this does not capture the updated handle. Prefer bindParameters.
+func (p *PreparedStatement) writeBindParametersToStream(pstream 
pb.FlightService_DoPutClient, desc *pb.FlightDescriptor) (*flight.Writer, 
error) {
        if p.paramBinding != nil {
                wr := flight.NewRecordWriter(pstream, 
ipc.WithSchema(p.paramBinding.Schema()))
                wr.SetFlightDescriptor(desc)
diff --git a/go/arrow/flight/flightsql/client_test.go 
b/go/arrow/flight/flightsql/client_test.go
index 727fe02aa7..33da79167c 100644
--- a/go/arrow/flight/flightsql/client_test.go
+++ b/go/arrow/flight/flightsql/client_test.go
@@ -448,9 +448,9 @@ func (s *FlightSqlClientSuite) 
TestPreparedStatementExecuteParamBinding() {
        expectedDesc := 
getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: 
[]byte(handle)})
 
        // mocked DoPut result
-    doPutPreparedStatementResult := 
&pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(updatedHandle)}
+       doPutPreparedStatementResult := 
&pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(updatedHandle)}
        resdata, _ := proto.Marshal(doPutPreparedStatementResult)
-       putResult := &pb.PutResult{ AppMetadata: resdata }
+       putResult := &pb.PutResult{AppMetadata: resdata}
 
        // mocked client stream for DoPut
        mockedPut := &mockDoPutClient{}
@@ -461,7 +461,7 @@ func (s *FlightSqlClientSuite) 
TestPreparedStatementExecuteParamBinding() {
        mockedPut.On("CloseSend").Return(nil)
        mockedPut.On("Recv").Return(putResult, nil)
 
-       infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: 
[]byte(handle)}
+       infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: 
[]byte(updatedHandle)}
        desc := getDesc(infoCmd)
        s.mockClient.On("GetFlightInfo", desc.Type, desc.Cmd, 
s.callOpts).Return(&emptyFlightInfo, nil)
 
@@ -525,9 +525,9 @@ func (s *FlightSqlClientSuite) 
TestPreparedStatementExecuteReaderBinding() {
        expectedDesc := 
getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: 
[]byte(query)})
 
        // mocked DoPut result
-    doPutPreparedStatementResult := 
&pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(query)}
+       doPutPreparedStatementResult := 
&pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(query)}
        resdata, _ := proto.Marshal(doPutPreparedStatementResult)
-       putResult := &pb.PutResult{ AppMetadata: resdata }
+       putResult := &pb.PutResult{AppMetadata: resdata}
 
        // mocked client stream for DoPut
        mockedPut := &mockDoPutClient{}

Reply via email to