This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 27066c1f9f GH-35240: [Go][FlightRPC] Fix crash in client middleware (#35241)
27066c1f9f is described below
commit 27066c1f9f6042748ec111df2bf6d8795c3d7d9a
Author: David Li <[email protected]>
AuthorDate: Fri Apr 21 07:40:29 2023 +0900
GH-35240: [Go][FlightRPC] Fix crash in client middleware (#35241)
### Rationale for this change
The Go gRPC interceptor API allows starting an RPC to fail and return an
error. When that happened, the Flight client middleware handler still tried to
read headers and trailers from the client stream, which is nil in that case,
and crashed on the nil pointer.
### What changes are included in this PR?
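Fix the crash: when starting the RPC fails, the client middleware no longer
reads headers and trailers from the nil stream to call `HeadersReceived`;
`CallCompleted` is still called with the error.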
### Are these changes tested?
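New tests were added. `TestClientStreamMiddlewareWithError` (DoGet) and
`TestClientUnaryMiddlewareWithError` (GetSchema) request an unregistered
compressor via `grpc.UseCompressor("foo")` to force the call to fail. They
then check that an error mentioning "Compressor is not installed" is returned
instead of a crash. The test server's `DoGet` now returns `NotFound` for
unknown tickets instead of indexing a missing record list.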
### Are there any user-facing changes?
There are no user-facing changes.
* Closes: #35240
Authored-by: David Li <[email protected]>
Signed-off-by: David Li <[email protected]>
---
go/arrow/flight/client.go | 4 --
go/arrow/flight/flight_middleware_test.go | 64 +++++++++++++++++++++++++++++++
go/arrow/flight/flight_test.go | 5 ++-
3 files changed, 68 insertions(+), 5 deletions(-)
diff --git a/go/arrow/flight/client.go b/go/arrow/flight/client.go
index da6b60c89b..b7287ec32a 100644
--- a/go/arrow/flight/client.go
+++ b/go/arrow/flight/client.go
@@ -120,10 +120,6 @@ func CreateClientMiddleware(middleware CustomClientMiddleware) ClientMiddleware
}
if err != nil {
- if isHdrs {
- md, _ := cs.Header()
- hdrs.HeadersReceived(ctx,
metadata.Join(md, cs.Trailer()))
- }
if isPostcall {
post.CallCompleted(ctx, err)
}
diff --git a/go/arrow/flight/flight_middleware_test.go b/go/arrow/flight/flight_middleware_test.go
index c75d5091ca..1ff9a6bc02 100755
--- a/go/arrow/flight/flight_middleware_test.go
+++ b/go/arrow/flight/flight_middleware_test.go
@@ -55,6 +55,17 @@ func (s *ServerMiddlewareAddHeader) CallCompleted(ctx context.Context, err error
}
}
+type ServerMiddlewareAddHeaderError struct{}
+
+func (s *ServerMiddlewareAddHeaderError) StartCall(ctx context.Context)
context.Context {
+ grpc.SetHeader(ctx, metadata.Pairs("foo", "bar"))
+ return nil
+}
+
+func (s *ServerMiddlewareAddHeaderError) CallCompleted(ctx context.Context,
err error) {
+ grpc.SetTrailer(ctx, metadata.Pairs("super", "duper"))
+}
+
type ServerTraceMiddleware struct{}
type tracetestKey struct{}
@@ -252,6 +263,33 @@ func TestClientStreamMiddleware(t *testing.T) {
assert.Equal(t, []string{"duper"}, middleware.md.Get("super"))
}
+func TestClientStreamMiddlewareWithError(t *testing.T) {
+ s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+
flight.CreateServerMiddleware(&ServerMiddlewareAddHeaderError{}),
+ })
+ s.Init("localhost:0")
+ f := &flightServer{}
+ s.RegisterFlightService(f)
+
+ go s.Serve()
+ defer s.Shutdown()
+
+ middle := &ClientTestSendHeaderMiddleware{}
+ client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil,
[]flight.ClientMiddleware{
+ flight.CreateClientMiddleware(middle),
+ }, grpc.WithTransportCredentials(insecure.NewCredentials()))
+
+ require.NoError(t, err)
+ defer client.Close()
+
+ // UseCompressor triggers a particular rare failure path.
+ _, err = client.DoGet(context.Background(), &flight.Ticket{Ticket:
[]byte("this flight does not exist")}, grpc.UseCompressor("foo"))
+ if err == nil {
+ t.Fatal("Expected error but got nothing")
+ }
+ assert.Contains(t, err.Error(), "Compressor is not installed")
+}
+
func TestClientUnaryMiddleware(t *testing.T) {
s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
flight.CreateServerMiddleware(&ServerMiddlewareAddHeader{}),
@@ -295,3 +333,29 @@ func TestClientUnaryMiddleware(t *testing.T) {
})
}
}
+
+func TestClientUnaryMiddlewareWithError(t *testing.T) {
+ s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+
flight.CreateServerMiddleware(&ServerMiddlewareAddHeaderError{}),
+ })
+ s.Init("localhost:0")
+ f := &flightServer{}
+ s.RegisterFlightService(f)
+
+ go s.Serve()
+ defer s.Shutdown()
+
+ middle := &ClientTestSendHeaderMiddleware{}
+ client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil,
[]flight.ClientMiddleware{
+ flight.CreateClientMiddleware(middle),
+ }, grpc.WithTransportCredentials(insecure.NewCredentials()))
+
+ require.NoError(t, err)
+ defer client.Close()
+
+ _, err = client.GetSchema(context.Background(),
&flight.FlightDescriptor{Path: []string{"this flight does not exist"}},
grpc.UseCompressor("foo"))
+ if err == nil {
+ t.Fatal("Expected error but got nothing")
+ }
+ assert.Contains(t, err.Error(), "Compressor is not installed")
+}
diff --git a/go/arrow/flight/flight_test.go b/go/arrow/flight/flight_test.go
index cd682ffefa..f8585df662 100755
--- a/go/arrow/flight/flight_test.go
+++ b/go/arrow/flight/flight_test.go
@@ -98,7 +98,10 @@ func (f *flightServer) GetSchema(_ context.Context, in *flight.FlightDescriptor)
}
func (f *flightServer) DoGet(tkt *flight.Ticket, fs
flight.FlightService_DoGetServer) error {
- recs := arrdata.Records[string(tkt.GetTicket())]
+ recs, ok := arrdata.Records[string(tkt.GetTicket())]
+ if !ok {
+ return status.Error(codes.NotFound, "flight not found")
+ }
w := flight.NewRecordWriter(fs, ipc.WithSchema(recs[0].Schema()))
for _, r := range recs {