This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new b4ac585ecb GH-36070: [Go][Flight] Add Flight Client Cookie Middleware (#36071)
b4ac585ecb is described below

commit b4ac585ecb4da610cc64e346e564ca86594aec53
Author: Matt Topol <[email protected]>
AuthorDate: Wed Jun 14 15:49:38 2023 -0400

    GH-36070: [Go][Flight] Add Flight Client Cookie Middleware (#36071)
    
    
    
    ### Rationale for this change
    See https://github.com/apache/arrow-adbc/issues/716
    
    ### What changes are included in this PR?
    A `NewClientCookieMiddleware` function is added to the Flight package. It
    returns a `ClientMiddleware` that can be used with both Flight and Flight
    SQL clients.
    
    ### Are these changes tested?
    Yes.
    
    ### Are there any user-facing changes?
    No.
    
    * Closes: #36070
    
    Authored-by: Matt Topol <[email protected]>
    Signed-off-by: Matt Topol <[email protected]>
---
 go/arrow/flight/cookie_middleware.go      | 122 +++++++++++++++
 go/arrow/flight/cookie_middleware_test.go | 241 ++++++++++++++++++++++++++++++
 2 files changed, 363 insertions(+)

diff --git a/go/arrow/flight/cookie_middleware.go 
b/go/arrow/flight/cookie_middleware.go
new file mode 100644
index 0000000000..27754a13b8
--- /dev/null
+++ b/go/arrow/flight/cookie_middleware.go
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flight
+
+import (
+       "context"
+       "net/http"
+       "strings"
+       "sync"
+       "time"
+
+       "google.golang.org/grpc/metadata"
+)
+
+// endOfTime is the time when session (non-persistent) cookies expire.
+// This instant is representable in most date/time formats (not just
+// Go's time.Time) and should be far enough in the future.
+// taken from Go's net/http/cookiejar/jar.go
+var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
+
+// NewClientCookieMiddleware returns a go-routine safe middleware for flight
+// clients which properly handles Set-Cookie headers to store cookies
+// in a cookie jar, and then requests are sent with those cookies added
+// as a Cookie header.
+func NewClientCookieMiddleware() ClientMiddleware {
+       return CreateClientMiddleware(&clientCookieMiddleware{jar: 
make(map[string]http.Cookie)})
+}
+
+type clientCookieMiddleware struct {
+       jar map[string]http.Cookie
+       mx  sync.Mutex
+}
+
+func (cc *clientCookieMiddleware) StartCall(ctx context.Context) 
context.Context {
+       cc.mx.Lock()
+       defer cc.mx.Unlock()
+
+       if len(cc.jar) == 0 {
+               return ctx
+       }
+
+       now := time.Now()
+
+       // Per RFC 6265 section 5.4, rather than adding multiple cookie strings
+       // or multiple cookie headers, multiple cookies are all sent as a single
+       // header value separated by semicolons.
+
+       // we will also clear any expired cookies from the jar while we 
determine
+       // the cookies to send.
+       cookies := make([]string, 0, len(cc.jar))
+       for id, c := range cc.jar {
+               if !c.Expires.After(now) {
+                       delete(cc.jar, id)
+                       continue
+               }
+
+               cookies = append(cookies, (&http.Cookie{Name: c.Name, Value: 
c.Value}).String())
+       }
+
+       if len(cookies) == 0 {
+               return ctx
+       }
+
+       return metadata.AppendToOutgoingContext(ctx, "Cookie", 
strings.Join(cookies, ";"))
+}
+
+func processCookieExpire(c *http.Cookie, now time.Time) (remove bool) {
+       // MaxAge takes precedence over Expires
+       if c.MaxAge < 0 {
+               return true
+       } else if c.MaxAge > 0 {
+               c.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
+       } else {
+               if c.Expires.IsZero() {
+                       c.Expires = endOfTime
+               } else {
+                       if !c.Expires.After(now) {
+                               return true
+                       }
+               }
+       }
+
+       return
+}
+
+func (cc *clientCookieMiddleware) HeadersReceived(ctx context.Context, md 
metadata.MD) {
+       // instead of replicating the logic for processing the Set-Cookie
+       // header, let's just make a fake response and use the built-in
+       // cookie processing. It's very non-trivial
+       cookies := (&http.Response{
+               Header: http.Header{"Set-Cookie": md.Get("set-cookie")},
+       }).Cookies()
+
+       now := time.Now()
+
+       cc.mx.Lock()
+       defer cc.mx.Unlock()
+
+       for _, c := range cookies {
+               id := c.Name + c.Path
+               if processCookieExpire(c, now) {
+                       delete(cc.jar, id)
+                       continue
+               }
+
+               cc.jar[id] = *c
+       }
+}
diff --git a/go/arrow/flight/cookie_middleware_test.go 
b/go/arrow/flight/cookie_middleware_test.go
new file mode 100644
index 0000000000..e48e9e6577
--- /dev/null
+++ b/go/arrow/flight/cookie_middleware_test.go
@@ -0,0 +1,241 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package flight_test
+
+import (
+       "context"
+       "errors"
+       "fmt"
+       "io"
+       "net/http"
+       "net/textproto"
+       "reflect"
+       "strings"
+       "testing"
+       "time"
+
+       "github.com/apache/arrow/go/v13/arrow/flight"
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/require"
+       "google.golang.org/grpc"
+       "google.golang.org/grpc/credentials/insecure"
+       "google.golang.org/grpc/metadata"
+)
+
// cut slices s around the first occurrence of sep. It returns the text
// before and after sep and whether sep was found; when sep is absent the
// result is (s, "", false). It duplicates strings.Cut (added in go1.18) so
// that these tests keep building with go1.17.
func cut(s, sep string) (before, after string, found bool) {
	idx := strings.Index(s, sep)
	if idx < 0 {
		return s, "", false
	}
	before, after = s[:idx], s[idx+len(sep):]
	return before, after, true
}
+
+type serverAddCookieMiddleware struct {
+       expectedCookies map[string]string
+
+       cookies []*http.Cookie
+}
+
+func (s *serverAddCookieMiddleware) StartCall(ctx context.Context) 
context.Context {
+       if s.expectedCookies == nil {
+               md := make(metadata.MD)
+               for _, c := range s.cookies {
+                       md.Append("Set-Cookie", c.String())
+               }
+               grpc.SetHeader(ctx, md)
+               return nil
+       }
+
+       cookies := metadata.ValueFromIncomingContext(ctx, "cookie")
+
+       got := make(map[string]string)
+       for _, line := range cookies {
+               line = textproto.TrimString(line)
+
+               var part string
+               for len(line) > 0 {
+                       part, line, _ = cut(line, ";")
+                       part = textproto.TrimString(part)
+                       if part == "" {
+                               continue
+                       }
+
+                       name, val, _ := cut(part, "=")
+                       name = textproto.TrimString(name)
+                       if len(val) > 1 && val[0] == '"' && val[len(val)-1] == 
'"' {
+                               val = val[1 : len(val)-1]
+                       }
+
+                       got[name] = val
+               }
+       }
+
+       if !reflect.DeepEqual(s.expectedCookies, got) {
+               panic(fmt.Sprintf("did not get expected cookies, expected %+v, 
got %+v", s.expectedCookies, got))
+       }
+
+       return nil
+}
+
+func (s *serverAddCookieMiddleware) CallCompleted(ctx context.Context, err 
error) {}
+
+func TestClientCookieMiddleware(t *testing.T) {
+       cookieMiddleware := &serverAddCookieMiddleware{}
+
+       s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+               flight.CreateServerMiddleware(cookieMiddleware),
+       })
+       s.Init("localhost:0")
+       f := &flightServer{}
+       s.RegisterFlightService(f)
+
+       go s.Serve()
+       defer s.Shutdown()
+
+       credsOpt := grpc.WithTransportCredentials(insecure.NewCredentials())
+
+       tests := []struct {
+               testname string
+               cookies  []*http.Cookie
+               expected map[string]string
+       }{
+               {"single cookie", []*http.Cookie{{Name: "Cookie-1", Value: 
"v$1", Raw: "Cookie-1=v$1"}},
+                       map[string]string{"Cookie-1": "v$1"}},
+               {"expired", []*http.Cookie{{
+                       Name: "NID", Value: "99=YsDT5", Expires: 
time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC),
+                       RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT", Raw: 
"NID=99=YsDT5; expires=Wed, 23-Nov-11 01:05:03 GMT"}},
+                       map[string]string{}},
+               {"multiple", []*http.Cookie{
+                       {Name: "negative maxage", Value: "foobar", MaxAge: -1},
+                       {Name: "special-1", Value: " z"},
+                       {Name: "cookie-2", Value: "v$2"},
+               },
+                       map[string]string{"special-1": " z", "cookie-2": 
"v$2"}},
+       }
+
+       makeReq := func(c flight.Client, t *testing.T) {
+               flightStream, err := c.ListFlights(context.Background(), 
&flight.Criteria{})
+               assert.NoError(t, err)
+
+               for {
+                       _, err := flightStream.Recv()
+                       if err != nil {
+                               if errors.Is(err, io.EOF) {
+                                       break
+                               }
+                               assert.NoError(t, err)
+                       }
+               }
+       }
+
+       for _, tt := range tests {
+               t.Run(tt.testname, func(t *testing.T) {
+                       cookieMiddleware.expectedCookies = nil
+
+                       client, err := 
flight.NewClientWithMiddleware(s.Addr().String(), nil,
+                               
[]flight.ClientMiddleware{flight.NewClientCookieMiddleware()}, credsOpt)
+                       require.NoError(t, err)
+                       defer client.Close()
+
+                       cookieMiddleware.cookies = tt.cookies
+                       makeReq(client, t)
+
+                       cookieMiddleware.expectedCookies = tt.expected
+                       makeReq(client, t)
+               })
+       }
+}
+
+func TestCookieExpiration(t *testing.T) {
+       cookieMiddleware := &serverAddCookieMiddleware{}
+
+       s := flight.NewServerWithMiddleware([]flight.ServerMiddleware{
+               flight.CreateServerMiddleware(cookieMiddleware),
+       })
+       s.Init("localhost:0")
+       f := &flightServer{}
+       s.RegisterFlightService(f)
+
+       go s.Serve()
+       defer s.Shutdown()
+
+       makeReq := func(c flight.Client, t *testing.T) {
+               flightStream, err := c.ListFlights(context.Background(), 
&flight.Criteria{})
+               assert.NoError(t, err)
+
+               for {
+                       _, err := flightStream.Recv()
+                       if err != nil {
+                               if errors.Is(err, io.EOF) {
+                                       break
+                               }
+                               assert.NoError(t, err)
+                       }
+               }
+       }
+
+       credsOpt := grpc.WithTransportCredentials(insecure.NewCredentials())
+       client, err := flight.NewClientWithMiddleware(s.Addr().String(), nil,
+               []flight.ClientMiddleware{flight.NewClientCookieMiddleware()}, 
credsOpt)
+       require.NoError(t, err)
+       defer client.Close()
+
+       // set cookies
+       cookieMiddleware.cookies = []*http.Cookie{
+               {Name: "foo", Value: "bar"},
+               {Name: "foo2", Value: "bar2", MaxAge: 1},
+       }
+       makeReq(client, t)
+
+       // validate set
+       cookieMiddleware.expectedCookies = map[string]string{
+               "foo": "bar", "foo2": "bar2",
+       }
+       makeReq(client, t)
+
+       // wait for foo2 to expire and validate it doesn't get sent
+       time.Sleep(1 * time.Second)
+       cookieMiddleware.expectedCookies = map[string]string{
+               "foo": "bar",
+       }
+       makeReq(client, t)
+
+       // update value
+       cookieMiddleware.cookies = []*http.Cookie{
+               {Name: "foo", Value: "baz"},
+       }
+       cookieMiddleware.expectedCookies = nil
+       makeReq(client, t)
+
+       // validate updated value is sent
+       cookieMiddleware.expectedCookies = map[string]string{
+               "foo": "baz",
+       }
+       makeReq(client, t)
+
+       // force delete cookie
+       cookieMiddleware.expectedCookies = nil
+       cookieMiddleware.cookies = []*http.Cookie{
+               {Name: "foo", MaxAge: -1}, // delete now!
+       }
+       makeReq(client, t)
+
+       // verify it's been deleted
+       cookieMiddleware.expectedCookies = map[string]string{}
+       makeReq(client, t)
+}

Reply via email to