This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 27066c1f9f GH-35240: [Go][FlightRPC] Fix crash in client middleware (#35241)
27066c1f9f is described below

commit 27066c1f9f6042748ec111df2bf6d8795c3d7d9a
Author: David Li <[email protected]>
AuthorDate: Fri Apr 21 07:40:29 2023 +0900

    GH-35240: [Go][FlightRPC] Fix crash in client middleware (#35241)
    
    
    
    ### Rationale for this change
    
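    The Go gRPC interceptor API allows starting an RPC to fail with an error
    (for example, when the client requests a compressor that is not
    installed). On that path, the stream interceptor built by
    `CreateClientMiddleware` still called `Header()` and `Trailer()` on the
    returned client stream, which is nil, and crashed with a nil pointer
    dereference.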
    
    ### What changes are included in this PR?
    
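    Fix the crash when those errors are encountered. When the RPC fails to
    start, the client middleware no longer reads headers/trailers from the
    stream or calls `HeadersReceived`. `CallCompleted` is still invoked with
    the error.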
    
    ### Are these changes tested?
    
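    New tests were added for a streaming call (`DoGet`) and a unary call
    (`GetSchema`). Both fail to start because the client requests an
    unregistered compressor via `grpc.UseCompressor("foo")`. The test
    server's `DoGet` now returns a `NotFound` error for unknown tickets
    instead of indexing into a missing entry.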
    
    ### Are there any user-facing changes?
    
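    There are no API changes. Note that client middleware's
    `HeadersReceived` hook is no longer invoked for calls that fail to start
    (this path previously crashed).

    * Closes: #35240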
    
    Authored-by: David Li <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 go/arrow/flight/client.go                 |  4 --
 go/arrow/flight/flight_middleware_test.go | 64 +++++++++++++++++++++++++++++++
 go/arrow/flight/flight_test.go            |  5 ++-
 3 files changed, 68 insertions(+), 5 deletions(-)

diff --git a/go/arrow/flight/client.go b/go/arrow/flight/client.go
index da6b60c89b..b7287ec32a 100644
--- a/go/arrow/flight/client.go
+++ b/go/arrow/flight/client.go
@@ -120,10 +120,6 @@ func CreateClientMiddleware(middleware 
CustomClientMiddleware) ClientMiddleware
                        }
 
                        if err != nil {
-                               if isHdrs {
-                                       md, _ := cs.Header()
-                                       hdrs.HeadersReceived(ctx, 
metadata.Join(md, cs.Trailer()))
-                               }
                                if isPostcall {
                                        post.CallCompleted(ctx, err)
                                }
diff --git a/go/arrow/flight/flight_middleware_test.go 
b/go/arrow/flight/flight_middleware_test.go
index c75d5091ca..1ff9a6bc02 100755
--- a/go/arrow/flight/flight_middleware_test.go
+++ b/go/arrow/flight/flight_middleware_test.go
@@ -55,6 +55,17 @@ func (s *ServerMiddlewareAddHeader) CallCompleted(ctx 
context.Context, err error
        }
 }
 
+type ServerMiddlewareAddHeaderError struct{}
+
+func (s *ServerMiddlewareAddHeaderError) StartCall(ctx context.Context) 
context.Context {
+       grpc.SetHeader(ctx, metadata.Pairs("foo", "bar"))
+       return nil
+}
+
+func (s *ServerMiddlewareAddHeaderError) CallCompleted(ctx context.Context, 
err error) {
+       grpc.SetTrailer(ctx, metadata.Pairs("super", "duper"))
+}
+
 type ServerTraceMiddleware struct{}
 
 type tracetestKey struct{}
@@ -252,6 +263,33 @@ func TestClientStreamMiddleware(t *testing.T) {
        assert.Equal(t, []string{"duper"}, middleware.md.Get("super"))
 }
 
+func TestClientStreamMiddlewareWithError(t *testing.T) {
+       s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+               
flight.CreateServerMiddleware(&ServerMiddlewareAddHeaderError{}),
+       })
+       s.Init("localhost:0")
+       f := &flightServer{}
+       s.RegisterFlightService(f)
+
+       go s.Serve()
+       defer s.Shutdown()
+
+       middle := &ClientTestSendHeaderMiddleware{}
+       client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, 
[]flight.ClientMiddleware{
+               flight.CreateClientMiddleware(middle),
+       }, grpc.WithTransportCredentials(insecure.NewCredentials()))
+
+       require.NoError(t, err)
+       defer client.Close()
+
+       // UseCompressor triggers a particular rare failure path.
+       _, err = client.DoGet(context.Background(), &flight.Ticket{Ticket: 
[]byte("this flight does not exist")}, grpc.UseCompressor("foo"))
+       if err == nil {
+               t.Fatal("Expected error but got nothing")
+       }
+       assert.Contains(t, err.Error(), "Compressor is not installed")
+}
+
 func TestClientUnaryMiddleware(t *testing.T) {
        s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
                flight.CreateServerMiddleware(&ServerMiddlewareAddHeader{}),
@@ -295,3 +333,29 @@ func TestClientUnaryMiddleware(t *testing.T) {
                })
        }
 }
+
+func TestClientUnaryMiddlewareWithError(t *testing.T) {
+       s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+               
flight.CreateServerMiddleware(&ServerMiddlewareAddHeaderError{}),
+       })
+       s.Init("localhost:0")
+       f := &flightServer{}
+       s.RegisterFlightService(f)
+
+       go s.Serve()
+       defer s.Shutdown()
+
+       middle := &ClientTestSendHeaderMiddleware{}
+       client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil, 
[]flight.ClientMiddleware{
+               flight.CreateClientMiddleware(middle),
+       }, grpc.WithTransportCredentials(insecure.NewCredentials()))
+
+       require.NoError(t, err)
+       defer client.Close()
+
+       _, err = client.GetSchema(context.Background(), 
&flight.FlightDescriptor{Path: []string{"this flight does not exist"}}, 
grpc.UseCompressor("foo"))
+       if err == nil {
+               t.Fatal("Expected error but got nothing")
+       }
+       assert.Contains(t, err.Error(), "Compressor is not installed")
+}
diff --git a/go/arrow/flight/flight_test.go b/go/arrow/flight/flight_test.go
index cd682ffefa..f8585df662 100755
--- a/go/arrow/flight/flight_test.go
+++ b/go/arrow/flight/flight_test.go
@@ -98,7 +98,10 @@ func (f *flightServer) GetSchema(_ context.Context, in 
*flight.FlightDescriptor)
 }
 
 func (f *flightServer) DoGet(tkt *flight.Ticket, fs 
flight.FlightService_DoGetServer) error {
-       recs := arrdata.Records[string(tkt.GetTicket())]
+       recs, ok := arrdata.Records[string(tkt.GetTicket())]
+       if !ok {
+               return status.Error(codes.NotFound, "flight not found")
+       }
 
        w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema()))
        for _, r := range recs {

Reply via email to