This is an automated email from the ASF dual-hosted git repository.

dcelasun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 26ef904  THRIFT-4914: Send context THeaders for client writes
26ef904 is described below

commit 26ef904600edc810f6514605c0611b7442a4c64e
Author: Yuxuan 'fishy' Wang <yuxuan.w...@reddit.com>
AuthorDate: Mon Aug 19 00:18:22 2019 -0700

    THRIFT-4914: Send context THeaders for client writes
    
    Client: go
    
    This is the second part of THRIFT-4914, which handles the client writing
    headers into requests (client -> server direction).
    
    In TStandardClient, when the context has write headers set and the
    protocol is THeaderProtocol, automatically extract those headers from the
    context object and set them on the THeaderProtocol to send over the wire.
    
    Client code can set headers on the context object by using the helper
    functions in header_context.go.
    
    Note that we use separate context keys for the read and write header key
    lists, so that code that is both a server and a client (for example, a
    server that calls other upstream thrift servers) does not automatically
    forward all headers to its upstream servers, and must explicitly set which
    headers to forward.
    
    To make auto-forwarding easier, also add a SetForwardHeaders function to
    TSimpleServer, which helps users automatically forward selected headers.
    
    This closes #1845.
---
 lib/go/thrift/client.go              | 10 ++++++++++
 lib/go/thrift/header_context.go      | 20 ++++++++++++++++++++
 lib/go/thrift/header_context_test.go | 33 ++++++++++++++++++++++++++++++++-
 lib/go/thrift/simple_server.go       | 24 ++++++++++++++++++++++++
 4 files changed, 86 insertions(+), 1 deletion(-)

diff --git a/lib/go/thrift/client.go b/lib/go/thrift/client.go
index 28791cc..b073a95 100644
--- a/lib/go/thrift/client.go
+++ b/lib/go/thrift/client.go
@@ -24,6 +24,16 @@ func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClien
 }
 
 func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error {
+       // Set headers from context object on THeaderProtocol
+       if headerProt, ok := oprot.(*THeaderProtocol); ok {
+               headerProt.ClearWriteHeaders()
+               for _, key := range GetWriteHeaderList(ctx) {
+                       if value, ok := GetHeader(ctx, key); ok {
+                               headerProt.SetWriteHeader(key, value)
+                       }
+               }
+       }
+
        if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil {
                return err
        }
diff --git a/lib/go/thrift/header_context.go b/lib/go/thrift/header_context.go
index 5d9104b..21e880d 100644
--- a/lib/go/thrift/header_context.go
+++ b/lib/go/thrift/header_context.go
@@ -32,6 +32,7 @@ type (
 // Values for headerKeyList.
 const (
        headerKeyListRead headerKeyList = iota
+       headerKeyListWrite
 )
 
 // SetHeader sets a header in the context.
@@ -70,6 +71,25 @@ func GetReadHeaderList(ctx context.Context) []string {
        return nil
 }
 
+// SetWriteHeaderList sets the key list of THeaders to write in the context.
+func SetWriteHeaderList(ctx context.Context, keys []string) context.Context {
+       return context.WithValue(
+               ctx,
+               headerKeyListWrite,
+               keys,
+       )
+}
+
+// GetWriteHeaderList returns the key list of THeaders to write from the context.
+func GetWriteHeaderList(ctx context.Context) []string {
+       if v := ctx.Value(headerKeyListWrite); v != nil {
+               if value, ok := v.([]string); ok {
+                       return value
+               }
+       }
+       return nil
+}
+
 // AddReadTHeaderToContext adds the whole THeader headers into context.
 func AddReadTHeaderToContext(ctx context.Context, headers THeaderMap) context.Context {
        keys := make([]string, 0, len(headers))
diff --git a/lib/go/thrift/header_context_test.go b/lib/go/thrift/header_context_test.go
index 33ac4ec..a1ea2d0 100644
--- a/lib/go/thrift/header_context_test.go
+++ b/lib/go/thrift/header_context_test.go
@@ -70,7 +70,7 @@ func TestSetGetHeader(t *testing.T) {
        )
 }
 
-func TestKeyList(t *testing.T) {
+func TestReadKeyList(t *testing.T) {
        headers := THeaderMap{
                "key1": "value1",
                "key2": "value2",
@@ -94,4 +94,35 @@ func TestKeyList(t *testing.T) {
        if !reflect.DeepEqual(headers, got) {
                t.Errorf("Expected header map %+v, got %+v", headers, got)
        }
+
+       writtenKeys := GetWriteHeaderList(ctx)
+       if len(writtenKeys) > 0 {
+               t.Errorf(
+                       "Expected empty GetWriteHeaderList() result, got %+v",
+                       writtenKeys,
+               )
+       }
+}
+
+func TestWriteKeyList(t *testing.T) {
+       keys := []string{
+               "key1",
+               "key2",
+       }
+       ctx := context.Background()
+
+       ctx = SetWriteHeaderList(ctx, keys)
+       got := GetWriteHeaderList(ctx)
+
+       if !reflect.DeepEqual(keys, got) {
+               t.Errorf("Expected header keys %+v, got %+v", keys, got)
+       }
+
+       readKeys := GetReadHeaderList(ctx)
+       if len(readKeys) > 0 {
+               t.Errorf(
+                       "Expected empty GetReadHeaderList() result, got %+v",
+                       readKeys,
+               )
+       }
 }
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index 9155cfb..f8efbed 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -42,6 +42,9 @@ type TSimpleServer struct {
        outputTransportFactory TTransportFactory
        inputProtocolFactory   TProtocolFactory
        outputProtocolFactory  TProtocolFactory
+
+       // Headers to auto forward in THeaderProtocol
+       forwardHeaders []string
 }
 
 func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer {
@@ -125,6 +128,26 @@ func (p *TSimpleServer) Listen() error {
        return p.serverTransport.Listen()
 }
 
+// SetForwardHeaders sets the list of header keys that will be auto forwarded
+// while using THeaderProtocol.
+//
+// "Forward" means that when the server is also a client of other upstream
+// thrift servers, the context object that users receive in processor functions
+// will have both read and write headers set, so the write headers get forwarded.
+// Users can always override the write headers by calling SetWriteHeaderList
+// before calling thrift client functions.
+func (p *TSimpleServer) SetForwardHeaders(headers []string) {
+       size := len(headers)
+       if size == 0 {
+               p.forwardHeaders = nil
+               return
+       }
+
+       keys := make([]string, size)
+       copy(keys, headers)
+       p.forwardHeaders = keys
+}
+
 func (p *TSimpleServer) innerAccept() (int32, error) {
        client, err := p.serverTransport.Accept()
        p.mu.Lock()
@@ -235,6 +258,7 @@ func (p *TSimpleServer) processRequests(client TTransport) error {
                                return err
                        }
                        ctx = AddReadTHeaderToContext(defaultCtx, headerProtocol.GetReadHeaders())
+                       ctx = SetWriteHeaderList(ctx, p.forwardHeaders)
                }
 
                ok, err := processor.Process(ctx, inputProtocol, outputProtocol)

Reply via email to