This is an automated email from the ASF dual-hosted git repository. dcelasun pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/thrift.git
The following commit(s) were added to refs/heads/master by this push: new 26ef904 THRIFT-4914: Send context THeaders for client writes 26ef904 is described below commit 26ef904600edc810f6514605c0611b7442a4c64e Author: Yuxuan 'fishy' Wang <yuxuan.w...@reddit.com> AuthorDate: Mon Aug 19 00:18:22 2019 -0700 THRIFT-4914: Send context THeaders for client writes Client: go This is the second part of THRIFT-4914, which handles the client-writing side of requests (the client -> server direction). In TStandardClient, when the context has write headers set and the protocol is THeaderProtocol, automatically extract the headers named in the context's write-header key list and set them on the THeaderProtocol so they are sent over the wire. Client code can set headers in the context object using the helper functions in header_context.go. Note that read and write headers use separate key lists, so code that is both a server and a client (for example, a server that calls other upstream thrift servers) does not automatically forward all received headers to its upstream servers; it must explicitly set which headers to forward. To make auto-forwarding easier, this commit also adds a SetForwardHeaders function to TSimpleServer, which lets users auto-forward selected headers. This closes #1845. 
--- lib/go/thrift/client.go | 10 ++++++++++ lib/go/thrift/header_context.go | 20 ++++++++++++++++++++ lib/go/thrift/header_context_test.go | 33 ++++++++++++++++++++++++++++++++- lib/go/thrift/simple_server.go | 24 ++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 1 deletion(-) diff --git a/lib/go/thrift/client.go b/lib/go/thrift/client.go index 28791cc..b073a95 100644 --- a/lib/go/thrift/client.go +++ b/lib/go/thrift/client.go @@ -24,6 +24,16 @@ func NewTStandardClient(inputProtocol, outputProtocol TProtocol) *TStandardClien } func (p *TStandardClient) Send(ctx context.Context, oprot TProtocol, seqId int32, method string, args TStruct) error { + // Set headers from context object on THeaderProtocol + if headerProt, ok := oprot.(*THeaderProtocol); ok { + headerProt.ClearWriteHeaders() + for _, key := range GetWriteHeaderList(ctx) { + if value, ok := GetHeader(ctx, key); ok { + headerProt.SetWriteHeader(key, value) + } + } + } + if err := oprot.WriteMessageBegin(method, CALL, seqId); err != nil { return err } diff --git a/lib/go/thrift/header_context.go b/lib/go/thrift/header_context.go index 5d9104b..21e880d 100644 --- a/lib/go/thrift/header_context.go +++ b/lib/go/thrift/header_context.go @@ -32,6 +32,7 @@ type ( // Values for headerKeyList. const ( headerKeyListRead headerKeyList = iota + headerKeyListWrite ) // SetHeader sets a header in the context. @@ -70,6 +71,25 @@ func GetReadHeaderList(ctx context.Context) []string { return nil } +// SetWriteHeaderList sets the key list of THeaders to write in the context. +func SetWriteHeaderList(ctx context.Context, keys []string) context.Context { + return context.WithValue( + ctx, + headerKeyListWrite, + keys, + ) +} + +// GetWriteHeaderList returns the key list of THeaders to write from the context. 
+func GetWriteHeaderList(ctx context.Context) []string { + if v := ctx.Value(headerKeyListWrite); v != nil { + if value, ok := v.([]string); ok { + return value + } + } + return nil +} + // AddReadTHeaderToContext adds the whole THeader headers into context. func AddReadTHeaderToContext(ctx context.Context, headers THeaderMap) context.Context { keys := make([]string, 0, len(headers)) diff --git a/lib/go/thrift/header_context_test.go b/lib/go/thrift/header_context_test.go index 33ac4ec..a1ea2d0 100644 --- a/lib/go/thrift/header_context_test.go +++ b/lib/go/thrift/header_context_test.go @@ -70,7 +70,7 @@ func TestSetGetHeader(t *testing.T) { ) } -func TestKeyList(t *testing.T) { +func TestReadKeyList(t *testing.T) { headers := THeaderMap{ "key1": "value1", "key2": "value2", @@ -94,4 +94,35 @@ func TestKeyList(t *testing.T) { if !reflect.DeepEqual(headers, got) { t.Errorf("Expected header map %+v, got %+v", headers, got) } + + writtenKeys := GetWriteHeaderList(ctx) + if len(writtenKeys) > 0 { + t.Errorf( + "Expected empty GetWriteHeaderList() result, got %+v", + writtenKeys, + ) + } +} + +func TestWriteKeyList(t *testing.T) { + keys := []string{ + "key1", + "key2", + } + ctx := context.Background() + + ctx = SetWriteHeaderList(ctx, keys) + got := GetWriteHeaderList(ctx) + + if !reflect.DeepEqual(keys, got) { + t.Errorf("Expected header keys %+v, got %+v", keys, got) + } + + readKeys := GetReadHeaderList(ctx) + if len(readKeys) > 0 { + t.Errorf( + "Expected empty GetReadHeaderList() result, got %+v", + readKeys, + ) + } } diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go index 9155cfb..f8efbed 100644 --- a/lib/go/thrift/simple_server.go +++ b/lib/go/thrift/simple_server.go @@ -42,6 +42,9 @@ type TSimpleServer struct { outputTransportFactory TTransportFactory inputProtocolFactory TProtocolFactory outputProtocolFactory TProtocolFactory + + // Headers to auto forward in THeaderProtocol + forwardHeaders []string } func NewTSimpleServer2(processor 
TProcessor, serverTransport TServerTransport) *TSimpleServer { @@ -125,6 +128,26 @@ func (p *TSimpleServer) Listen() error { return p.serverTransport.Listen() } +// SetForwardHeaders sets the list of header keys that will be auto forwarded +// while using THeaderProtocol. +// +// "forward" means that when the server is also a client to other upstream +// thrift servers, the context object user gets in the processor functions will +// have both read and write headers set, with write headers being forwarded. +// Users can always override the write headers by calling SetWriteHeaderList +// before calling thrift client functions. +func (p *TSimpleServer) SetForwardHeaders(headers []string) { + size := len(headers) + if size == 0 { + p.forwardHeaders = nil + return + } + + keys := make([]string, size) + copy(keys, headers) + p.forwardHeaders = keys +} + func (p *TSimpleServer) innerAccept() (int32, error) { client, err := p.serverTransport.Accept() p.mu.Lock() @@ -235,6 +258,7 @@ func (p *TSimpleServer) processRequests(client TTransport) error { return err } ctx = AddReadTHeaderToContext(defaultCtx, headerProtocol.GetReadHeaders()) + ctx = SetWriteHeaderList(ctx, p.forwardHeaders) } ok, err := processor.Process(ctx, inputProtocol, outputProtocol)