This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new a2c4466  THRIFT-5278: Allow set protoID in go THeader transport/protocol
a2c4466 is described below

commit a2c44665b416522477cffa6752c2f323768d0507
Author: Yuxuan 'fishy' Wang <[email protected]>
AuthorDate: Mon Sep 21 12:33:26 2020 -0700

    THRIFT-5278: Allow set protoID in go THeader transport/protocol
    
    Client: go
    
    In Go library code, allow setting the underlying protoID to a
    non-default (TCompactProtocol) one for THeaderTransport/THeaderProtocol.
---
 lib/go/thrift/header_protocol.go       | 60 +++++++++++++++++++++++++++++-----
 lib/go/thrift/header_protocol_test.go  | 18 +++++++++-
 lib/go/thrift/header_transport.go      | 44 +++++++++++++++++++++++--
 lib/go/thrift/header_transport_test.go | 23 ++++++++++---
 4 files changed, 128 insertions(+), 17 deletions(-)

diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go
index 428b261..f86d558 100644
--- a/lib/go/thrift/header_protocol.go
+++ b/lib/go/thrift/header_protocol.go
@@ -37,31 +37,73 @@ type THeaderProtocol struct {
 }
 
 // NewTHeaderProtocol creates a new THeaderProtocol from the underlying
-// transport. The passed in transport will be wrapped with THeaderTransport.
+// transport with default protocol ID.
+//
+// The passed in transport will be wrapped with THeaderTransport.
 //
 // Note that THeaderTransport handles frame and zlib by itself,
 // so the underlying transport should be a raw socket transports (TSocket or TSSLSocket),
 // instead of rich transports like TZlibTransport or TFramedTransport.
 func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
-       t := NewTHeaderTransport(trans)
-       p, _ := THeaderProtocolDefault.GetProtocol(t)
+       p, err := newTHeaderProtocolWithProtocolID(trans, 
THeaderProtocolDefault)
+       if err != nil {
+               // Since we used THeaderProtocolDefault this should never 
happen,
+               // but put a sanity check here just in case.
+               panic(err)
+       }
+       return p
+}
+
+func newTHeaderProtocolWithProtocolID(trans TTransport, protoID 
THeaderProtocolID) (*THeaderProtocol, error) {
+       t, err := NewTHeaderTransportWithProtocolID(trans, protoID)
+       if err != nil {
+               return nil, err
+       }
+       p, err := t.protocolID.GetProtocol(t)
+       if err != nil {
+               return nil, err
+       }
        return &THeaderProtocol{
                transport: t,
                protocol:  p,
-       }
+       }, nil
 }
 
-type tHeaderProtocolFactory struct{}
+type tHeaderProtocolFactory struct {
+       protoID THeaderProtocolID
+}
 
-func (tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
-       return NewTHeaderProtocol(trans)
+func (f tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
+       p, err := newTHeaderProtocolWithProtocolID(trans, f.protoID)
+       if err != nil {
+               // Currently there's no way for external users to construct a
+               // valid factory with invalid protoID, so this should never
+               // happen. But put a sanity check here just in case in the
+               // future a bug made that possible.
+               panic(err)
+       }
+       return p
 }
 
-// NewTHeaderProtocolFactory creates a factory for THeader.
+// NewTHeaderProtocolFactory creates a factory for THeader with default protocol
+// ID.
 //
 // It's a wrapper for NewTHeaderProtocol
 func NewTHeaderProtocolFactory() TProtocolFactory {
-       return tHeaderProtocolFactory{}
+       return tHeaderProtocolFactory{
+               protoID: THeaderProtocolDefault,
+       }
+}
+
+// NewTHeaderProtocolFactoryWithProtocolID creates a factory for THeader with
+// given protocol ID.
+func NewTHeaderProtocolFactoryWithProtocolID(protoID THeaderProtocolID) 
(TProtocolFactory, error) {
+       if err := protoID.Validate(); err != nil {
+               return nil, err
+       }
+       return tHeaderProtocolFactory{
+               protoID: protoID,
+       }, nil
 }
 
 // Transport returns the underlying transport.
diff --git a/lib/go/thrift/header_protocol_test.go b/lib/go/thrift/header_protocol_test.go
index 9b6019b..f66ea64 100644
--- a/lib/go/thrift/header_protocol_test.go
+++ b/lib/go/thrift/header_protocol_test.go
@@ -24,5 +24,21 @@ import (
 )
 
 func TestReadWriteHeaderProtocol(t *testing.T) {
-       ReadWriteProtocolTest(t, NewTHeaderProtocolFactory())
+       t.Run(
+               "default",
+               func(t *testing.T) {
+                       ReadWriteProtocolTest(t, NewTHeaderProtocolFactory())
+               },
+       )
+
+       t.Run(
+               "compact",
+               func(t *testing.T) {
+                       f, err := 
NewTHeaderProtocolFactoryWithProtocolID(THeaderProtocolCompact)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       ReadWriteProtocolTest(t, f)
+               },
+       )
 }
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index e208034..562d02f 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -75,6 +75,15 @@ const (
        THeaderProtocolDefault                   = THeaderProtocolBinary
 )
 
+// Placeholder transport that Validate hands to GetProtocol (result discarded); global to avoid repeated allocations.
+var globalMemoryBuffer = NewTMemoryBuffer()
+
+// Validate checks whether the THeaderProtocolID is a valid/supported one.
+func (id THeaderProtocolID) Validate() error {
+       _, err := id.GetProtocol(globalMemoryBuffer)
+       return err
+}
+
 // GetProtocol gets the corresponding TProtocol from the wrapped protocol id.
 func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
        switch id {
@@ -84,7 +93,7 @@ func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
                        fmt.Sprintf("THeader protocol id %d not supported", id),
                )
        case THeaderProtocolBinary:
-               return NewTBinaryProtocolFactoryDefault().GetProtocol(trans), 
nil
+               return NewTBinaryProtocolTransport(trans), nil
        case THeaderProtocolCompact:
                return NewTCompactProtocol(trans), nil
        }
@@ -93,11 +102,12 @@ func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
 // THeaderTransformID defines the numeric id of the transform used.
 type THeaderTransformID int32
 
-// THeaderTransformID values
+// THeaderTransformID values.
+//
+// Values not defined here are not currently supported, namely HMAC and Snappy.
 const (
        TransformNone THeaderTransformID = iota // 0, no special handling
        TransformZlib                           // 1, zlib
-       // Rest of the values are not currently supported, namely HMAC and 
Snappy.
 )
 
 var supportedTransformIDs = map[THeaderTransformID]bool{
@@ -285,6 +295,34 @@ func NewTHeaderTransport(trans TTransport) *THeaderTransport {
        }
 }
 
+// NewTHeaderTransportWithProtocolID creates THeaderTransport from the
+// underlying transport, with given protocol ID set.
+//
+// If trans is already a *THeaderTransport, it will be returned as is.
+// NOTE(review): protoID is currently NOT applied to it in that case; confirm intent.
+//
+// If the passed in protocol ID is an invalid/unsupported one,
+// this function returns error.
+//
+// Setting the protocol ID is only useful for client transports.
+// For servers, the protocol ID will be overridden again by the one
+// set by the client, to ensure that servers always speak the same
+// dialect as the client.
+func NewTHeaderTransportWithProtocolID(trans TTransport, protoID 
THeaderProtocolID) (*THeaderTransport, error) {
+       if err := protoID.Validate(); err != nil {
+               return nil, err
+       }
+       if ht, ok := trans.(*THeaderTransport); ok {
+               return ht, nil
+       }
+       return &THeaderTransport{
+               transport:    trans,
+               reader:       bufio.NewReader(trans),
+               writeHeaders: make(THeaderMap),
+               protocolID:   protoID,
+       }, nil
+}
+
 // Open calls the underlying transport's Open function.
 func (t *THeaderTransport) Open() error {
        return t.transport.Open()
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 320fb2a..5b47680 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -28,10 +28,13 @@ import (
        "testing/quick"
 )
 
-func TestTHeaderHeadersReadWrite(t *testing.T) {
+func testTHeaderHeadersReadWriteProtocolID(t *testing.T, protoID 
THeaderProtocolID) {
        trans := NewTMemoryBuffer()
        reader := NewTHeaderTransport(trans)
-       writer := NewTHeaderTransport(trans)
+       writer, err := NewTHeaderTransportWithProtocolID(trans, protoID)
+       if err != nil {
+               t.Fatal(err)
+       }
 
        const key1 = "key1"
        const value1 = "value1"
@@ -98,10 +101,10 @@ func TestTHeaderHeadersReadWrite(t *testing.T) {
                        read,
                )
        }
-       if prot := reader.Protocol(); prot != THeaderProtocolBinary {
+       if prot := reader.Protocol(); prot != protoID {
                t.Errorf(
                        "reader.Protocol() expected %d, got %d",
-                       THeaderProtocolBinary,
+                       protoID,
                        prot,
                )
        }
@@ -121,6 +124,18 @@ func TestTHeaderHeadersReadWrite(t *testing.T) {
        }
 }
 
+func TestTHeaderHeadersReadWrite(t *testing.T) {
+       for label, id := range map[string]THeaderProtocolID{
+               "default": THeaderProtocolDefault,
+               "binary":  THeaderProtocolBinary,
+               "compact": THeaderProtocolCompact,
+       } {
+               t.Run(label, func(t *testing.T) {
+                       testTHeaderHeadersReadWriteProtocolID(t, id)
+               })
+       }
+}
+
 func TestTHeaderTransportNoDoubleWrapping(t *testing.T) {
        trans := NewTMemoryBuffer()
        orig := NewTHeaderTransport(trans)

Reply via email to