This is an automated email from the ASF dual-hosted git repository.

dcelasun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 4d46c11  THRIFT-4612: THeader support in go library
4d46c11 is described below

commit 4d46c1124450eeb77d2a6adc7ea5fab304bfeb4a
Author: Yuxuan 'fishy' Wang <[email protected]>
AuthorDate: Fri Jun 7 20:47:18 2019 +0800

    THRIFT-4612: THeader support in go library
    
    Client: go
    
    Implement THeaderTransport and THeaderProtocol, with support of:
    
    * clients:
      - headers
      - framedBinary
      - unframedBinary
      - framedCompact
      - unframedCompact
    * transforms:
      - none
      - zlib
    * info types:
      - key_value
    * wrapped protocols:
      - TBinary
      - TCompact
    
    The support list is in general on par with the THeader implementation in
    the python library.
    
    The cross-test passes, except for the ones related to the cpp/nodejs http
    transport, which were also failing for non-theader protocols.
    
    This change also fixes two bugs:
    
    1. A small issue in test/go/src/bin/testserver/main.go
    2. A bug in TFrameTransport go implementation
---
 lib/go/thrift/application_exception.go |   6 +
 lib/go/thrift/framed_transport.go      |  16 +-
 lib/go/thrift/header_protocol.go       | 300 ++++++++++++++
 lib/go/thrift/header_protocol_test.go  |  28 ++
 lib/go/thrift/header_transport.go      | 692 +++++++++++++++++++++++++++++++++
 lib/go/thrift/header_transport_test.go | 108 +++++
 lib/go/thrift/simple_server.go         |  22 +-
 test/go/src/bin/testserver/main.go     |   4 +-
 test/go/src/common/client.go           |   2 +
 test/go/src/common/server.go           |   2 +
 test/known_failures_Linux.json         |   8 +
 test/tests.json                        |   3 +-
 12 files changed, 1182 insertions(+), 9 deletions(-)

diff --git a/lib/go/thrift/application_exception.go 
b/lib/go/thrift/application_exception.go
index b9d7eed..0023c57 100644
--- a/lib/go/thrift/application_exception.go
+++ b/lib/go/thrift/application_exception.go
@@ -28,6 +28,9 @@ const (
        MISSING_RESULT                 = 5
        INTERNAL_ERROR                 = 6
        PROTOCOL_ERROR                 = 7
+       INVALID_TRANSFORM              = 8
+       INVALID_PROTOCOL               = 9
+       UNSUPPORTED_CLIENT_TYPE        = 10
 )
 
 var defaultApplicationExceptionMessage = map[int32]string{
@@ -39,6 +42,9 @@ var defaultApplicationExceptionMessage = map[int32]string{
        MISSING_RESULT:                 "missing result",
        INTERNAL_ERROR:                 "unknown internal error",
        PROTOCOL_ERROR:                 "unknown protocol error",
+       INVALID_TRANSFORM:              "Invalid transform",
+       INVALID_PROTOCOL:               "Invalid protocol",
+       UNSUPPORTED_CLIENT_TYPE:        "Unsupported client type",
 }
 
 // Application level Thrift exception
diff --git a/lib/go/thrift/framed_transport.go 
b/lib/go/thrift/framed_transport.go
index 81fa65a..34275b5 100644
--- a/lib/go/thrift/framed_transport.go
+++ b/lib/go/thrift/framed_transport.go
@@ -93,7 +93,21 @@ func (p *TFramedTransport) Read(buf []byte) (l int, err 
error) {
                l, err = p.Read(tmp)
                copy(buf, tmp)
                if err == nil {
-                       err = NewTTransportExceptionFromError(fmt.Errorf("Not 
enough frame size %d to read %d bytes", frameSize, len(buf)))
+                       // Note: It's important to only return an error when l
+                       // is zero.
+                       // In io.Reader.Read interface, it's perfectly fine to
+                       // return partial data and nil error, which means
+                       // "This is all the data we have right now without
+                       // blocking. If you need the full data, call Read again
+                       // or use io.ReadFull instead".
+                       // Returning partial data with an error actually means
+                       // there's no more data after the partial data just
+                       // returned, which is not true in this case
+                       // (it might be that the other end just hasn't written
+                       // them yet).
+                       if l == 0 {
+                               err = 
NewTTransportExceptionFromError(fmt.Errorf("Not enough frame size %d to read %d 
bytes", frameSize, len(buf)))
+                       }
                        return
                }
        }
diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go
new file mode 100644
index 0000000..0cf48f7
--- /dev/null
+++ b/lib/go/thrift/header_protocol.go
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "context"
+)
+
+// THeaderProtocol is a thrift protocol that implements THeader:
+// https://github.com/apache/thrift/blob/master/doc/specs/HeaderFormat.md
+//
+// It supports either binary or compact protocol as the wrapped protocol.
+//
+// Most of the THeader handling happens inside THeaderTransport; this type
+// forwards each protocol call to the currently wrapped protocol.
+type THeaderProtocol struct {
+       transport *THeaderTransport
+
+       // The wrapped protocol. NewTHeaderProtocol initializes it with the
+       // protocol for THeaderProtocolDefault, and it is replaced at the start
+       // of every message (WriteMessageBegin / ReadMessageBegin) with the
+       // protocol selected by the transport.
+       protocol TProtocol
+}
+
+// NewTHeaderProtocol wraps trans in a THeaderTransport and returns a
+// THeaderProtocol that reads and writes through it.
+//
+// THeaderTransport does framing and zlib handling itself, so trans should be
+// a raw socket transport (TSocket or TSSLSocket) rather than a rich transport
+// such as TZlibTransport or TFramedTransport.
+func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
+       headerTrans := NewTHeaderTransport(trans)
+       // The error is deliberately dropped: THeaderProtocolDefault is one of
+       // the ids THeaderProtocolID.GetProtocol always supports.
+       proto, _ := THeaderProtocolDefault.GetProtocol(headerTrans)
+       return &THeaderProtocol{protocol: proto, transport: headerTrans}
+}
+
+// tHeaderProtocolFactory is the TProtocolFactory handed out by
+// NewTHeaderProtocolFactory.
+type tHeaderProtocolFactory struct{}
+
+// GetProtocol wraps trans with a fresh THeaderProtocol.
+func (f tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
+       var proto TProtocol = NewTHeaderProtocol(trans)
+       return proto
+}
+
+// NewTHeaderProtocolFactory creates a TProtocolFactory whose GetProtocol
+// simply calls NewTHeaderProtocol.
+func NewTHeaderProtocolFactory() TProtocolFactory {
+       var factory tHeaderProtocolFactory
+       return factory
+}
+
+// Transport returns the underlying transport.
+//
+// The returned TTransport is always a *THeaderTransport.
+func (p *THeaderProtocol) Transport() TTransport {
+       return p.transport
+}
+
+// GetReadHeaders returns the THeaderMap read by the underlying
+// THeaderTransport.
+func (p *THeaderProtocol) GetReadHeaders() THeaderMap {
+       return p.transport.GetReadHeaders()
+}
+
+// SetWriteHeader sets the write header key to value on the underlying
+// THeaderTransport.
+func (p *THeaderProtocol) SetWriteHeader(key, value string) {
+       p.transport.SetWriteHeader(key, value)
+}
+
+// ClearWriteHeaders removes every write header previously set on the
+// underlying THeaderTransport.
+func (p *THeaderProtocol) ClearWriteHeaders() {
+       p.transport.ClearWriteHeaders()
+}
+
+// AddTransform adds a transform to use for writing, on the underlying
+// THeaderTransport.
+func (p *THeaderProtocol) AddTransform(transform THeaderTransformID) error {
+       return p.transport.AddTransform(transform)
+}
+
+// Flush flushes the underlying THeaderTransport.
+func (p *THeaderProtocol) Flush(ctx context.Context) error {
+       return p.transport.Flush(ctx)
+}
+
+// WriteMessageBegin switches to the wrapped protocol selected by the
+// transport, records seqID as the transport's SequenceID, and writes the
+// message header with the wrapped protocol.
+func (p *THeaderProtocol) WriteMessageBegin(name string, typeID TMessageType, seqID int32) error {
+       proto, err := p.transport.Protocol().GetProtocol(p.transport)
+       if err != nil {
+               return err
+       }
+       p.protocol = proto
+       p.transport.SequenceID = seqID
+       return proto.WriteMessageBegin(name, typeID, seqID)
+}
+
+// WriteMessageEnd ends the message with the wrapped protocol and then
+// flushes the transport, so each message goes out as soon as it ends.
+func (p *THeaderProtocol) WriteMessageEnd() error {
+       err := p.protocol.WriteMessageEnd()
+       if err == nil {
+               err = p.transport.Flush(context.Background())
+       }
+       return err
+}
+
+// WriteStructBegin forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteStructBegin(name string) error {
+       return p.protocol.WriteStructBegin(name)
+}
+
+// WriteStructEnd forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteStructEnd() error {
+       return p.protocol.WriteStructEnd()
+}
+
+// WriteFieldBegin forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteFieldBegin(name string, typeID TType, id int16) error {
+       return p.protocol.WriteFieldBegin(name, typeID, id)
+}
+
+// WriteFieldEnd forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteFieldEnd() error {
+       return p.protocol.WriteFieldEnd()
+}
+
+// WriteFieldStop forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteFieldStop() error {
+       return p.protocol.WriteFieldStop()
+}
+
+// WriteMapBegin forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
+       return p.protocol.WriteMapBegin(keyType, valueType, size)
+}
+
+// WriteMapEnd forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteMapEnd() error {
+       return p.protocol.WriteMapEnd()
+}
+
+// WriteListBegin forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteListBegin(elemType TType, size int) error {
+       return p.protocol.WriteListBegin(elemType, size)
+}
+
+// WriteListEnd forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteListEnd() error {
+       return p.protocol.WriteListEnd()
+}
+
+// WriteSetBegin forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteSetBegin(elemType TType, size int) error {
+       return p.protocol.WriteSetBegin(elemType, size)
+}
+
+// WriteSetEnd forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteSetEnd() error {
+       return p.protocol.WriteSetEnd()
+}
+
+// WriteBool forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteBool(value bool) error {
+       return p.protocol.WriteBool(value)
+}
+
+// WriteByte forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteByte(value int8) error {
+       return p.protocol.WriteByte(value)
+}
+
+// WriteI16 forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteI16(value int16) error {
+       return p.protocol.WriteI16(value)
+}
+
+// WriteI32 forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteI32(value int32) error {
+       return p.protocol.WriteI32(value)
+}
+
+// WriteI64 forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteI64(value int64) error {
+       return p.protocol.WriteI64(value)
+}
+
+// WriteDouble forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteDouble(value float64) error {
+       return p.protocol.WriteDouble(value)
+}
+
+// WriteString forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteString(value string) error {
+       return p.protocol.WriteString(value)
+}
+
+// WriteBinary forwards to the wrapped protocol.
+func (p *THeaderProtocol) WriteBinary(value []byte) error {
+       return p.protocol.WriteBinary(value)
+}
+
+// ReadMessageBegin reads the next frame from the transport, switches the
+// wrapped protocol to the one selected by the transport, and reads the
+// message header with it.
+//
+// If the wrapped protocol cannot be created and the error is a
+// TApplicationException, that exception is written back to the peer on a
+// best-effort basis (errors while replying are ignored), and the original
+// error is returned.
+func (p *THeaderProtocol) ReadMessageBegin() (name string, typeID TMessageType, seqID int32, err error) {
+       if err = p.transport.ReadFrame(); err != nil {
+               return
+       }
+
+       var newProto TProtocol
+       newProto, err = p.transport.Protocol().GetProtocol(p.transport)
+       if err != nil {
+               tAppExc, ok := err.(TApplicationException)
+               if !ok {
+                       return
+               }
+               // seqID is still zero here because the message header has not
+               // been read. Reply with the sequence id recorded on the
+               // transport (parseHeaders sets it from the THeader frame) so
+               // the peer can match the exception to its request.
+               replySeqID := p.transport.SequenceID
+               if e := p.protocol.WriteMessageBegin("", EXCEPTION, replySeqID); e != nil {
+                       return
+               }
+               if e := tAppExc.Write(p.protocol); e != nil {
+                       return
+               }
+               if e := p.protocol.WriteMessageEnd(); e != nil {
+                       return
+               }
+               // Best effort: a flush failure must not replace the original error.
+               _ = p.transport.Flush(context.Background())
+               return
+       }
+       p.protocol = newProto
+
+       return p.protocol.ReadMessageBegin()
+}
+
+// ReadMessageEnd forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadMessageEnd() error {
+       return p.protocol.ReadMessageEnd()
+}
+
+// ReadStructBegin forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadStructBegin() (name string, err error) {
+       return p.protocol.ReadStructBegin()
+}
+
+// ReadStructEnd forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadStructEnd() error {
+       return p.protocol.ReadStructEnd()
+}
+
+// ReadFieldBegin forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadFieldBegin() (name string, typeID TType, id int16, err error) {
+       return p.protocol.ReadFieldBegin()
+}
+
+// ReadFieldEnd forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadFieldEnd() error {
+       return p.protocol.ReadFieldEnd()
+}
+
+// ReadMapBegin forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
+       return p.protocol.ReadMapBegin()
+}
+
+// ReadMapEnd forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadMapEnd() error {
+       return p.protocol.ReadMapEnd()
+}
+
+// ReadListBegin forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadListBegin() (elemType TType, size int, err error) {
+       return p.protocol.ReadListBegin()
+}
+
+// ReadListEnd forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadListEnd() error {
+       return p.protocol.ReadListEnd()
+}
+
+// ReadSetBegin forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadSetBegin() (elemType TType, size int, err error) {
+       return p.protocol.ReadSetBegin()
+}
+
+// ReadSetEnd forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadSetEnd() error {
+       return p.protocol.ReadSetEnd()
+}
+
+// ReadBool forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadBool() (value bool, err error) {
+       return p.protocol.ReadBool()
+}
+
+// ReadByte forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadByte() (value int8, err error) {
+       return p.protocol.ReadByte()
+}
+
+// ReadI16 forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadI16() (value int16, err error) {
+       return p.protocol.ReadI16()
+}
+
+// ReadI32 forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadI32() (value int32, err error) {
+       return p.protocol.ReadI32()
+}
+
+// ReadI64 forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadI64() (value int64, err error) {
+       return p.protocol.ReadI64()
+}
+
+// ReadDouble forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadDouble() (value float64, err error) {
+       return p.protocol.ReadDouble()
+}
+
+// ReadString forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadString() (value string, err error) {
+       return p.protocol.ReadString()
+}
+
+// ReadBinary forwards to the wrapped protocol.
+func (p *THeaderProtocol) ReadBinary() (value []byte, err error) {
+       return p.protocol.ReadBinary()
+}
+
+// Skip forwards to the wrapped protocol.
+func (p *THeaderProtocol) Skip(fieldType TType) error {
+       return p.protocol.Skip(fieldType)
+}
diff --git a/lib/go/thrift/header_protocol_test.go 
b/lib/go/thrift/header_protocol_test.go
new file mode 100644
index 0000000..9b6019b
--- /dev/null
+++ b/lib/go/thrift/header_protocol_test.go
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "testing"
+)
+
+// TestReadWriteHeaderProtocol runs the shared protocol round-trip checks
+// (ReadWriteProtocolTest) against protocols from NewTHeaderProtocolFactory.
+func TestReadWriteHeaderProtocol(t *testing.T) {
+       factory := NewTHeaderProtocolFactory()
+       ReadWriteProtocolTest(t, factory)
+}
diff --git a/lib/go/thrift/header_transport.go 
b/lib/go/thrift/header_transport.go
new file mode 100644
index 0000000..3e68460
--- /dev/null
+++ b/lib/go/thrift/header_transport.go
@@ -0,0 +1,692 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "bufio"
+       "bytes"
+       "compress/zlib"
+       "context"
+       "encoding/binary"
+       "errors"
+       "fmt"
+       "io"
+       "io/ioutil"
+)
+
+// size32 is the size in bytes of a 32-bit integer.
+const size32 = 4
+
+// headerMeta is the fixed-size start of a THeader frame (right after the
+// frame length field): magic+flags, sequence id and header length. It is
+// decoded big-endian with binary.Read.
+type headerMeta struct {
+       MagicFlags   uint32
+       SequenceID   int32
+       HeaderLength uint16
+}
+
+// headerMetaSize is the encoded size of headerMeta: 4 + 4 + 2 bytes.
+const headerMetaSize = 10
+
+// clientType is the wire format of the peer, as detected by ReadFrame.
+type clientType int
+
+const (
+       clientUnknown         clientType = iota // nothing read yet
+       clientHeaders                           // framed THeader
+       clientFramedBinary                      // framed TBinary
+       clientUnframedBinary                    // unframed TBinary
+       clientFramedCompact                     // framed TCompact
+       clientUnframedCompact                   // unframed TCompact
+)
+
+// Constants from the THeader format specification:
+// https://github.com/apache/thrift/blob/master/doc/specs/HeaderFormat.md
+//
+// THeaderHeaderMagic is the value expected in the bits selected by
+// THeaderHeaderMask of the first word after the frame length; the bits
+// selected by THeaderFlagsMask carry the flags. THeaderMaxFrameSize is the
+// largest frame size ReadFrame accepts.
+const (
+       THeaderHeaderMagic  uint32 = 0x0fff0000
+       THeaderHeaderMask   uint32 = 0xffff0000
+       THeaderFlagsMask    uint32 = 0x0000ffff
+       THeaderMaxFrameSize uint32 = 0x3fffffff
+)
+
+// THeaderMap holds the string key/value headers carried by THeader frames.
+type THeaderMap map[string]string
+
+// THeaderProtocolID identifies the protocol wrapped inside THeader.
+type THeaderProtocolID int32
+
+// Supported THeaderProtocolID values. THeaderProtocolDefault is the one a
+// new THeaderTransport starts with.
+const (
+       THeaderProtocolBinary  THeaderProtocolID = 0x00
+       THeaderProtocolCompact THeaderProtocolID = 0x02
+       THeaderProtocolDefault                   = THeaderProtocolBinary
+)
+
+// GetProtocol returns a TProtocol of the kind identified by id, built on top
+// of trans. Unsupported ids yield a TApplicationException of type
+// INVALID_PROTOCOL.
+func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
+       switch id {
+       case THeaderProtocolBinary:
+               return NewTBinaryProtocolFactoryDefault().GetProtocol(trans), nil
+       case THeaderProtocolCompact:
+               return NewTCompactProtocol(trans), nil
+       default:
+               return nil, NewTApplicationException(
+                       INVALID_PROTOCOL,
+                       fmt.Sprintf("THeader protocol id %d not supported", id),
+               )
+       }
+}
+
+// THeaderTransformID is the numeric id of a THeader transform.
+type THeaderTransformID int32
+
+// THeaderTransformID values. The remaining transforms (namely HMAC and
+// Snappy) are not currently supported.
+const (
+       TransformNone THeaderTransformID = 0 // no special handling
+       TransformZlib THeaderTransformID = 1 // zlib
+)
+
+// supportedTransformIDs is the set of transform ids this implementation
+// handles.
+var supportedTransformIDs = map[THeaderTransformID]bool{
+       TransformNone: true,
+       TransformZlib: true,
+}
+
+// TransformReader is an io.ReadCloser that undoes THeader transforms while
+// reading.
+//
+// Each AddTransform call wraps the current Reader, so bytes coming from the
+// base reader pass through the earliest added transform first.
+type TransformReader struct {
+       io.Reader
+
+       closers []io.Closer
+}
+
+var _ io.ReadCloser = (*TransformReader)(nil)
+
+// NewTransformReaderWithCapacity creates a TransformReader on top of
+// baseReader with room reserved for capacity closers.
+//
+// When the number of transforms is not known in advance,
+//
+//     &TransformReader{Reader: baseReader}
+//
+// works just as well.
+func NewTransformReaderWithCapacity(baseReader io.Reader, capacity int) *TransformReader {
+       tr := &TransformReader{Reader: baseReader}
+       tr.closers = make([]io.Closer, 0, capacity)
+       return tr
+}
+
+// Close closes the readers created by AddTransform, most recently added
+// first. It stops at, and returns, the first error encountered.
+func (tr *TransformReader) Close() error {
+       for i := range tr.closers {
+               closer := tr.closers[len(tr.closers)-1-i]
+               if err := closer.Close(); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
+// AddTransform wraps the current Reader with the decoder for transform id.
+//
+// TransformNone is a no-op; unsupported ids yield a TApplicationException of
+// type INVALID_TRANSFORM.
+func (tr *TransformReader) AddTransform(id THeaderTransformID) error {
+       switch id {
+       case TransformNone:
+               return nil
+       case TransformZlib:
+               zr, err := zlib.NewReader(tr.Reader)
+               if err != nil {
+                       return err
+               }
+               tr.Reader = zr
+               tr.closers = append(tr.closers, zr)
+               return nil
+       default:
+               return NewTApplicationException(
+                       INVALID_TRANSFORM,
+                       fmt.Sprintf("THeaderTransformID %d not supported", id),
+               )
+       }
+}
+
+// TransformWriter is an io.WriteCloser that applies THeader transforms while
+// writing.
+//
+// Each AddTransform call wraps the current Writer, so data written to a
+// TransformWriter passes through the most recently added transform first.
+type TransformWriter struct {
+       io.Writer
+
+       closers []io.Closer
+}
+
+var _ io.WriteCloser = (*TransformWriter)(nil)
+
+// NewTransformWriter creates a TransformWriter on top of baseWriter and adds
+// transforms to it in slice order. It returns the first AddTransform error.
+func NewTransformWriter(baseWriter io.Writer, transforms []THeaderTransformID) (io.WriteCloser, error) {
+       tw := &TransformWriter{
+               Writer:  baseWriter,
+               closers: make([]io.Closer, 0, len(transforms)),
+       }
+       for i := range transforms {
+               if err := tw.AddTransform(transforms[i]); err != nil {
+                       return nil, err
+               }
+       }
+       return tw, nil
+}
+
+// Close closes the writers created by AddTransform (not baseWriter), most
+// recently added first. It stops at, and returns, the first error
+// encountered.
+func (tw *TransformWriter) Close() error {
+       for i := range tw.closers {
+               closer := tw.closers[len(tw.closers)-1-i]
+               if err := closer.Close(); err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
+// AddTransform wraps the current Writer with the encoder for transform id.
+//
+// TransformNone is a no-op; unsupported ids yield a TApplicationException of
+// type INVALID_TRANSFORM.
+func (tw *TransformWriter) AddTransform(id THeaderTransformID) error {
+       switch id {
+       case TransformNone:
+               return nil
+       case TransformZlib:
+               zw := zlib.NewWriter(tw.Writer)
+               tw.Writer = zw
+               tw.closers = append(tw.closers, zw)
+               return nil
+       default:
+               return NewTApplicationException(
+                       INVALID_TRANSFORM,
+                       fmt.Sprintf("THeaderTransformID %d not supported", id),
+               )
+       }
+}
+
+// THeaderInfoType is the type id of an info section in a THeader header.
+type THeaderInfoType int32
+
+// Supported THeaderInfoType values; the other info types are not supported.
+const (
+       // InfoKeyValue marks an info section of string key/value pairs.
+       InfoKeyValue THeaderInfoType = 1
+)
+
+// THeaderTransport is a TTransport that implements THeader.
+//
+// It does framing and zlib handling itself, so the underlying transport
+// should be a raw socket transport (TSocket or TSSLSocket) rather than a rich
+// transport such as TZlibTransport or TFramedTransport.
+type THeaderTransport struct {
+       // Sequence id and flags; parseHeaders fills them from each THeader
+       // frame it reads.
+       SequenceID int32
+       Flags      uint32
+
+       transport TTransport
+
+       // Headers read from the last THeader frame, and headers to write.
+       readHeaders  THeaderMap
+       writeHeaders THeaderMap
+
+       // reader buffers transport so ReadFrame can peek at the first bytes.
+       reader *bufio.Reader
+       // Once a frame is detected, it is read fully into frameBuffer.
+       frameBuffer bytes.Buffer
+       // While non-nil, Read consumes frameReader instead of reader, and its
+       // io.EOF means end of the current frame rather than end of the whole
+       // transport.
+       frameReader io.ReadCloser
+
+       // Pending outgoing payload and the transforms used for writing.
+       writeBuffer     bytes.Buffer
+       writeTransforms []THeaderTransformID
+
+       clientType clientType
+       protocolID THeaderProtocolID
+
+       // buffer is scratch space reused to avoid repeated allocations; 4
+       // bytes covers both header padding (at most 4 bytes) and the 4-byte
+       // frame size.
+       buffer [4]byte
+}
+
+var _ TTransport = (*THeaderTransport)(nil)
+
+// NewTHeaderTransport wraps trans in a THeaderTransport whose wrapped
+// protocol starts out as THeaderProtocolDefault.
+//
+// THeaderTransport does framing and zlib handling itself, so trans should be
+// a raw socket transport (TSocket or TSSLSocket) rather than a rich transport
+// such as TZlibTransport or TFramedTransport.
+func NewTHeaderTransport(trans TTransport) *THeaderTransport {
+       t := new(THeaderTransport)
+       t.transport = trans
+       t.reader = bufio.NewReader(trans)
+       t.writeHeaders = make(THeaderMap)
+       t.protocolID = THeaderProtocolDefault
+       return t
+}
+
+// Open opens the underlying transport.
+func (t *THeaderTransport) Open() error {
+       err := t.transport.Open()
+       return err
+}
+
+// IsOpen reports whether the underlying transport is open.
+func (t *THeaderTransport) IsOpen() bool {
+       open := t.transport.IsOpen()
+       return open
+}
+
+// ReadFrame reads the start of the next message, detects the client type,
+// and for framed clients loads the whole frame into memory.
+//
+// It does nothing while the current frame is still being consumed. Unframed
+// binary/compact clients are detected by peeking only, leaving the bytes for
+// Read. For framed clients the complete frame is read first:
+//   - a frame cut short by the peer is reported as io.ErrUnexpectedEOF
+//     instead of being silently accepted as a truncated frame;
+//   - a frame shorter than the 4-byte version/magic word is rejected,
+//     instead of peeking past its end into stale buffer bytes.
+func (t *THeaderTransport) ReadFrame() error {
+       if !t.needReadFrame() {
+               // No need to read frame, skipping.
+               return nil
+       }
+       // Peek and handle the first 32 bits.
+       // They could either be the length field of a framed message,
+       // or the first bytes of an unframed message.
+       buf, err := t.reader.Peek(size32)
+       if err != nil {
+               return err
+       }
+       frameSize := binary.BigEndian.Uint32(buf)
+       if frameSize&VERSION_MASK == VERSION_1 {
+               t.clientType = clientUnframedBinary
+               return nil
+       }
+       if buf[0] == COMPACT_PROTOCOL_ID && buf[1]&COMPACT_VERSION_MASK == COMPACT_VERSION {
+               t.clientType = clientUnframedCompact
+               return nil
+       }
+
+       // At this point it should be a framed message,
+       // sanity check on frameSize then discard the peeked part.
+       if frameSize > THeaderMaxFrameSize {
+               return NewTProtocolExceptionWithType(
+                       SIZE_LIMIT,
+                       errors.New("frame too large"),
+               )
+       }
+       // Cannot fail: the same size32 bytes were just peeked successfully.
+       t.reader.Discard(size32)
+
+       // Read the frame fully into frameBuffer, starting from an empty buffer
+       // so bytes left over from an earlier failed read never leak in.
+       t.frameBuffer.Reset()
+       n, err := io.Copy(
+               &t.frameBuffer,
+               io.LimitReader(t.reader, int64(frameSize)),
+       )
+       if err != nil {
+               t.frameBuffer.Reset()
+               return err
+       }
+       if n < int64(frameSize) {
+               // io.Copy treats the early EOF of a LimitReader as success: the
+               // peer closed the connection in the middle of a frame.
+               t.frameBuffer.Reset()
+               return NewTTransportExceptionFromError(io.ErrUnexpectedEOF)
+       }
+       t.frameReader = ioutil.NopCloser(&t.frameBuffer)
+
+       // Every framed format starts with a 4-byte version/magic word. For a
+       // shorter frame, Bytes()[:size32] below would read past the frame's
+       // data into the buffer's spare capacity, so reject it (the frame has
+       // been fully consumed, which keeps the stream in sync).
+       if frameSize < size32 {
+               if err := t.endOfFrame(); err != nil {
+                       return err
+               }
+               return NewTProtocolExceptionWithType(
+                       INVALID_DATA,
+                       fmt.Errorf("frame size %d is too small", frameSize),
+               )
+       }
+
+       // Peek and handle the next 32 bits.
+       buf = t.frameBuffer.Bytes()[:size32]
+       version := binary.BigEndian.Uint32(buf)
+       if version&THeaderHeaderMask == THeaderHeaderMagic {
+               t.clientType = clientHeaders
+               return t.parseHeaders(frameSize)
+       }
+       if version&VERSION_MASK == VERSION_1 {
+               t.clientType = clientFramedBinary
+               return nil
+       }
+       if buf[0] == COMPACT_PROTOCOL_ID && buf[1]&COMPACT_VERSION_MASK == COMPACT_VERSION {
+               t.clientType = clientFramedCompact
+               return nil
+       }
+       if err := t.endOfFrame(); err != nil {
+               return err
+       }
+       return NewTProtocolExceptionWithType(
+               NOT_IMPLEMENTED,
+               errors.New("unsupported client transport type"),
+       )
+}
+
+// endOfFrame finishes the current frame: it closes frameReader and, even if
+// Close fails, empties frameBuffer and clears frameReader so the next read
+// of a framed client starts a new frame.
+func (t *THeaderTransport) endOfFrame() error {
+       err := t.frameReader.Close()
+       t.frameBuffer.Reset()
+       t.frameReader = nil
+       return err
+}
+
+// parseHeaders parses the THeader header at the front of t.frameBuffer and
+// leaves t.frameBuffer positioned at the start of the payload.
+//
+// frameSize is the size of the frame after its 4-byte length field. On
+// success it has set Flags, SequenceID, the wrapped protocol id, the read
+// headers and, when the frame declares transforms, a TransformReader as
+// t.frameReader. The peer-supplied transform count is validated (not
+// negative, not larger than the header itself) before it sizes an
+// allocation.
+func (t *THeaderTransport) parseHeaders(frameSize uint32) error {
+       if t.clientType != clientHeaders {
+               return nil
+       }
+
+       var err error
+       var meta headerMeta
+       if err = binary.Read(&t.frameBuffer, binary.BigEndian, &meta); err != nil {
+               return err
+       }
+       // No underflow: binary.Read only succeeds if the frame buffer (which
+       // holds at most frameSize bytes) had at least headerMetaSize bytes.
+       frameSize -= headerMetaSize
+       t.Flags = meta.MagicFlags & THeaderFlagsMask
+       t.SequenceID = meta.SequenceID
+       // The header length field counts 4-byte words.
+       headerLength := int64(meta.HeaderLength) * 4
+       if int64(frameSize) < headerLength {
+               return NewTProtocolExceptionWithType(
+                       SIZE_LIMIT,
+                       errors.New("header size is larger than the whole frame"),
+               )
+       }
+       headerBuf := NewTMemoryBuffer()
+       _, err = io.Copy(headerBuf, io.LimitReader(&t.frameBuffer, headerLength))
+       if err != nil {
+               return err
+       }
+       hp := NewTCompactProtocol(headerBuf)
+
+       // At this point the header is already read into headerBuf,
+       // and t.frameBuffer starts from the actual payload.
+       protoID, err := hp.readVarint32()
+       if err != nil {
+               return err
+       }
+       t.protocolID = THeaderProtocolID(protoID)
+       var transformCount int32
+       transformCount, err = hp.readVarint32()
+       if err != nil {
+               return err
+       }
+       // transformCount comes straight from the peer. Every transform id
+       // takes at least one byte of the header, so a count that is negative
+       // or larger than headerLength is malformed; reject it before using it
+       // to size an allocation.
+       if transformCount < 0 {
+               return NewTProtocolExceptionWithType(
+                       NEGATIVE_SIZE,
+                       fmt.Errorf("negative THeader transform count %d", transformCount),
+               )
+       }
+       if int64(transformCount) > headerLength {
+               return NewTProtocolExceptionWithType(
+                       SIZE_LIMIT,
+                       fmt.Errorf("THeader transform count %d exceeds header size %d", transformCount, headerLength),
+               )
+       }
+       if transformCount > 0 {
+               reader := NewTransformReaderWithCapacity(
+                       &t.frameBuffer,
+                       int(transformCount),
+               )
+               t.frameReader = reader
+               transformIDs := make([]THeaderTransformID, transformCount)
+               for i := range transformIDs {
+                       id, err := hp.readVarint32()
+                       if err != nil {
+                               return err
+                       }
+                       transformIDs[i] = THeaderTransformID(id)
+               }
+               // The transform IDs on the wire were added in writing order,
+               // so on the reading side they are added in reverse order.
+               for i := len(transformIDs) - 1; i >= 0; i-- {
+                       if err := reader.AddTransform(transformIDs[i]); err != nil {
+                               return err
+                       }
+               }
+       }
+
+       // The info part does not use the transforms yet, so it's
+       // important to continue using headerBuf.
+       headers := make(THeaderMap)
+       for {
+               infoType, err := hp.readVarint32()
+               if err == io.EOF {
+                       break
+               }
+               if err != nil {
+                       return err
+               }
+               if THeaderInfoType(infoType) != InfoKeyValue {
+                       // Stop reading the info section at the first
+                       // unsupported info type.
+                       break
+               }
+               count, err := hp.readVarint32()
+               if err != nil {
+                       return err
+               }
+               for i := 0; i < int(count); i++ {
+                       key, err := hp.ReadString()
+                       if err != nil {
+                               return err
+                       }
+                       value, err := hp.ReadString()
+                       if err != nil {
+                               return err
+                       }
+                       headers[key] = value
+               }
+       }
+       t.readHeaders = headers
+
+       return nil
+}
+
+// needReadFrame reports whether ReadFrame has to read from the transport:
+// either nothing has been read on this connection yet, or the client is
+// framed and the previous frame has been fully consumed.
+func (t *THeaderTransport) needReadFrame() bool {
+       switch {
+       case t.clientType == clientUnknown:
+               return true
+       case t.isFramed():
+               return t.frameReader == nil
+       default:
+               return false
+       }
+}
+
+func (t *THeaderTransport) Read(p []byte) (read int, err error) {
+       err = t.ReadFrame()
+       if err != nil {
+               return
+       }
+       if t.frameReader != nil {
+               read, err = t.frameReader.Read(p)
+               if err == io.EOF {
+                       err = t.endOfFrame()
+                       if err != nil {
+                               return
+                       }
+                       if read < len(p) {
+                               var nextRead int
+                               nextRead, err = t.Read(p[read:])
+                               read += nextRead
+                       }
+               }
+               return
+       }
+       return t.reader.Read(p)
+}
+
+// Write appends p to the pending write buffer without touching the
+// underlying transport.
+//
+// The buffered bytes are only sent when Flush is called.
+func (t *THeaderTransport) Write(p []byte) (int, error) {
+       n, err := t.writeBuffer.Write(p)
+       return n, err
+}
+
+// Flush writes the appropriate header and the write buffer to the underlying
+// transport, then flushes the underlying transport.
+//
+// A connection whose client type is still unknown is switched to THeader
+// before writing. For THeader clients the buffered payload is passed through
+// the configured write transforms and prefixed with a THeader header
+// (protocol ID, transform IDs, key/value write headers) and a 4-byte
+// big-endian frame length. Framed legacy clients only get the 4-byte length
+// prefix; unframed legacy clients get the raw buffered bytes. Flushing an
+// empty write buffer is a no-op; otherwise the write buffer is reset when
+// Flush returns, whether or not it succeeded.
+func (t *THeaderTransport) Flush(ctx context.Context) error {
+       if t.writeBuffer.Len() == 0 {
+               return nil
+       }
+
+       defer t.writeBuffer.Reset()
+
+       switch t.clientType {
+       default:
+               fallthrough
+       case clientUnknown:
+               t.clientType = clientHeaders
+               fallthrough
+       case clientHeaders:
+               // Header section: protocol ID, transform IDs and info
+               // headers, encoded with the compact protocol's varints and
+               // strings.
+               headers := NewTMemoryBuffer()
+               hp := NewTCompactProtocol(headers)
+               if _, err := hp.writeVarint32(int32(t.protocolID)); err != nil {
+                       return NewTTransportExceptionFromError(err)
+               }
+               if _, err := hp.writeVarint32(int32(len(t.writeTransforms))); err != nil {
+                       return NewTTransportExceptionFromError(err)
+               }
+               for _, transform := range t.writeTransforms {
+                       if _, err := hp.writeVarint32(int32(transform)); err != nil {
+                               return NewTTransportExceptionFromError(err)
+                       }
+               }
+               if len(t.writeHeaders) > 0 {
+                       if _, err := hp.writeVarint32(int32(InfoKeyValue)); err != nil {
+                               return NewTTransportExceptionFromError(err)
+                       }
+                       if _, err := hp.writeVarint32(int32(len(t.writeHeaders))); err != nil {
+                               return NewTTransportExceptionFromError(err)
+                       }
+                       for key, value := range t.writeHeaders {
+                               if err := hp.WriteString(key); err != nil {
+                                       return NewTTransportExceptionFromError(err)
+                               }
+                               if err := hp.WriteString(value); err != nil {
+                                       return NewTTransportExceptionFromError(err)
+                               }
+                       }
+               }
+               // The header length is stored in 4-byte words, so pad the
+               // header section with zeros to a multiple of 4 bytes.
+               padding := 4 - headers.Len()%4
+               if padding < 4 {
+                       buf := t.buffer[:padding]
+                       for i := range buf {
+                               buf[i] = 0
+                       }
+                       if _, err := headers.Write(buf); err != nil {
+                               return NewTTransportExceptionFromError(err)
+                       }
+               }
+               // HeaderLength is a uint16 word count; refuse to emit a frame
+               // whose header length would silently wrap around and corrupt
+               // the frame on the wire.
+               headerWords := headers.Len() / 4
+               if headerWords > 0xffff {
+                       return NewTTransportExceptionFromError(fmt.Errorf(
+                               "THeader header size %d bytes exceeds the maximum of %d bytes",
+                               headers.Len(),
+                               0xffff*4,
+                       ))
+               }
+
+               var payload bytes.Buffer
+               meta := headerMeta{
+                       MagicFlags:   THeaderHeaderMagic + t.Flags&THeaderFlagsMask,
+                       SequenceID:   t.SequenceID,
+                       HeaderLength: uint16(headerWords),
+               }
+               if err := binary.Write(&payload, binary.BigEndian, meta); err != nil {
+                       return NewTTransportExceptionFromError(err)
+               }
+               if _, err := io.Copy(&payload, headers); err != nil {
+                       return NewTTransportExceptionFromError(err)
+               }
+
+               // Only the buffered payload goes through the write
+               // transforms; meta and the header section are untransformed.
+               writer, err := NewTransformWriter(&payload, t.writeTransforms)
+               if err != nil {
+                       return NewTTransportExceptionFromError(err)
+               }
+               if _, err := io.Copy(writer, &t.writeBuffer); err != nil {
+                       return NewTTransportExceptionFromError(err)
+               }
+               // Closing the transform writer flushes anything the
+               // transforms still buffer into payload.
+               if err := writer.Close(); err != nil {
+                       return NewTTransportExceptionFromError(err)
+               }
+
+               // First write frame length
+               buf := t.buffer[:size32]
+               binary.BigEndian.PutUint32(buf, uint32(payload.Len()))
+               if _, err := t.transport.Write(buf); err != nil {
+                       return NewTTransportExceptionFromError(err)
+               }
+               // Then write the payload
+               if _, err := io.Copy(t.transport, &payload); err != nil {
+                       return NewTTransportExceptionFromError(err)
+               }
+
+       case clientFramedBinary, clientFramedCompact:
+               buf := t.buffer[:size32]
+               binary.BigEndian.PutUint32(buf, uint32(t.writeBuffer.Len()))
+               if _, err := t.transport.Write(buf); err != nil {
+                       return NewTTransportExceptionFromError(err)
+               }
+               fallthrough
+       case clientUnframedBinary, clientUnframedCompact:
+               if _, err := io.Copy(t.transport, &t.writeBuffer); err != nil {
+                       return NewTTransportExceptionFromError(err)
+               }
+       }
+
+       // Do not flush the underlying transport if ctx was cancelled meanwhile.
+       select {
+       default:
+       case <-ctx.Done():
+               return NewTTransportExceptionFromError(ctx.Err())
+       }
+
+       return t.transport.Flush(ctx)
+}
+
+// Close flushes any pending writes and closes the underlying transport.
+//
+// The underlying transport is closed even when the final Flush fails, so a
+// failed flush does not leak the connection. In that case the Flush error is
+// returned; otherwise the error (if any) from closing the underlying
+// transport is returned.
+func (t *THeaderTransport) Close() error {
+       flushErr := t.Flush(context.Background())
+       closeErr := t.transport.Close()
+       if flushErr != nil {
+               return flushErr
+       }
+       return closeErr
+}
+
+// RemainingBytes reports the underlying transport's RemainingBytes.
+//
+// No per-frame bookkeeping is done here. With compression transforms in play,
+// the bytes left in a frame on the wire need not match the decoded bytes still
+// readable. The value is therefore delegated to the underlying transport as-is,
+// even for framed clients.
+func (t *THeaderTransport) RemainingBytes() uint64 {
+       underlying := t.transport
+       return underlying.RemainingBytes()
+}
+
+// GetReadHeaders returns the key/value headers stored from the last THeader
+// header section read by this transport.
+//
+// The returned map is the transport's own; it is not copied.
+func (t *THeaderTransport) GetReadHeaders() THeaderMap {
+       headers := t.readHeaders
+       return headers
+}
+
+// SetWriteHeader records a key/value header to be written by subsequent
+// Flush calls, replacing any earlier value stored for key.
+func (t *THeaderTransport) SetWriteHeader(key, value string) {
+       headers := t.writeHeaders
+       headers[key] = value
+}
+
+// ClearWriteHeaders drops every write header set so far by replacing the
+// map with a fresh, empty one.
+func (t *THeaderTransport) ClearWriteHeaders() {
+       t.writeHeaders = THeaderMap{}
+}
+
+// AddTransform appends transform to the list of transforms applied when
+// writing.
+//
+// If transform is not a supported ID, it returns a NOT_IMPLEMENTED
+// TProtocolException and leaves the list unchanged.
+func (t *THeaderTransport) AddTransform(transform THeaderTransformID) error {
+       if supported := supportedTransformIDs[transform]; supported {
+               t.writeTransforms = append(t.writeTransforms, transform)
+               return nil
+       }
+       return NewTProtocolExceptionWithType(
+               NOT_IMPLEMENTED,
+               fmt.Errorf("THeaderTransformID %d not supported", transform),
+       )
+}
+
+// Protocol returns the ID of the protocol wrapped by this THeaderTransport.
+//
+// Legacy framed/unframed binary and compact client types imply their
+// protocol; for any other client type the configured protocolID is reported.
+func (t *THeaderTransport) Protocol() THeaderProtocolID {
+       switch t.clientType {
+       case clientFramedBinary, clientUnframedBinary:
+               return THeaderProtocolBinary
+       case clientFramedCompact, clientUnframedCompact:
+               return THeaderProtocolCompact
+       default:
+               return t.protocolID
+       }
+}
+
+// isFramed reports whether the detected client type uses length-prefixed
+// frames: THeader clients and the framed binary/compact legacy clients.
+func (t *THeaderTransport) isFramed() bool {
+       ct := t.clientType
+       return ct == clientHeaders || ct == clientFramedBinary || ct == clientFramedCompact
+}
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
new file mode 100644
index 0000000..94af010
--- /dev/null
+++ b/lib/go/thrift/header_transport_test.go
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+       "context"
+       "io/ioutil"
+       "testing"
+)
+
+// TestTHeaderHeadersReadWrite writes two THeader frames through a chain of
+// zlib, zlib and none transforms, with two key/value headers, and checks that
+// a second THeaderTransport on the same buffer reads back the concatenated
+// payloads, the detected protocol/client type, and the headers.
+func TestTHeaderHeadersReadWrite(t *testing.T) {
+       trans := NewTMemoryBuffer()
+       reader := NewTHeaderTransport(trans)
+       writer := NewTHeaderTransport(trans)
+
+       const key1 = "key1"
+       const value1 = "value1"
+       const key2 = "key2"
+       const value2 = "value2"
+       const payload1 = "hello, world1\n"
+       const payload2 = "hello, world2\n"
+
+       // Write
+       if err := writer.AddTransform(TransformZlib); err != nil {
+               t.Fatalf(
+                       "writer.AddTransform(TransformZlib) returned error: %v",
+                       err,
+               )
+       }
+       // Use double zlib to make sure that we close them in the right order.
+       if err := writer.AddTransform(TransformZlib); err != nil {
+               t.Fatalf(
+                       "writer.AddTransform(TransformZlib) returned error: %v",
+                       err,
+               )
+       }
+       if err := writer.AddTransform(TransformNone); err != nil {
+               t.Fatalf(
+                       "writer.AddTransform(TransformNone) returned error: %v",
+                       err,
+               )
+       }
+       writer.SetWriteHeader(key1, value1)
+       writer.SetWriteHeader(key2, value2)
+       // Flush each payload separately so reading has to cross a frame
+       // boundary. A failed write or flush makes the read checks below
+       // meaningless, so stop the test right away.
+       for _, payload := range []string{payload1, payload2} {
+               if _, err := writer.Write([]byte(payload)); err != nil {
+                       t.Fatalf("writer.Write returned error: %v", err)
+               }
+               if err := writer.Flush(context.Background()); err != nil {
+                       t.Fatalf("writer.Flush returned error: %v", err)
+               }
+       }
+
+       // Read
+       read, err := ioutil.ReadAll(reader)
+       if err != nil {
+               t.Errorf("Read returned error: %v", err)
+       }
+       if string(read) != payload1+payload2 {
+               t.Errorf(
+                       "Read content expected %q, got %q",
+                       payload1+payload2,
+                       read,
+               )
+       }
+       if prot := reader.Protocol(); prot != THeaderProtocolBinary {
+               t.Errorf(
+                       "reader.Protocol() expected %d, got %d",
+                       THeaderProtocolBinary,
+                       prot,
+               )
+       }
+       if reader.clientType != clientHeaders {
+               t.Errorf(
+                       "reader.clientType expected %d, got %d",
+                       clientHeaders,
+                       reader.clientType,
+               )
+       }
+       headers := reader.GetReadHeaders()
+       if len(headers) != 2 || headers[key1] != value1 || headers[key2] != value2 {
+               t.Errorf(
+                       "reader.GetReadHeaders() expected {%q: %q, %q: %q}, got %+v",
+                       key1, value1, key2, value2,
+                       headers,
+               )
+       }
+}
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index 6035802..7db36c2 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -187,12 +187,24 @@ func (p *TSimpleServer) processRequests(client TTransport) error {
        if err != nil {
                return err
        }
-       outputTransport, err := p.outputTransportFactory.GetTransport(client)
-       if err != nil {
-               return err
-       }
        inputProtocol := p.inputProtocolFactory.GetProtocol(inputTransport)
-       outputProtocol := p.outputProtocolFactory.GetProtocol(outputTransport)
+       var outputTransport TTransport
+       var outputProtocol TProtocol
+
+       // for THeaderProtocol, we must use the same protocol instance for
+       // input and output so that the response is in the same dialect that
+       // the server detected the request was in.
+       if _, ok := inputProtocol.(*THeaderProtocol); ok {
+               outputProtocol = inputProtocol
+       } else {
+               oTrans, err := p.outputTransportFactory.GetTransport(client)
+               if err != nil {
+                       return err
+               }
+               outputTransport = oTrans
+               outputProtocol = 
p.outputProtocolFactory.GetProtocol(outputTransport)
+       }
+
        defer func() {
                if e := recover(); e != nil {
                        log.Printf("panic in processor: %s: %s", e, 
debug.Stack())
diff --git a/test/go/src/bin/testserver/main.go b/test/go/src/bin/testserver/main.go
index ca2d967..6fc1185 100644
--- a/test/go/src/bin/testserver/main.go
+++ b/test/go/src/bin/testserver/main.go
@@ -32,7 +32,7 @@ var host = flag.String("host", "localhost", "Host to connect")
 var port = flag.Int64("port", 9090, "Port number to connect")
 var domain_socket = flag.String("domain-socket", "", "Domain Socket (e.g. 
/tmp/ThriftTest.thrift), instead of host and port")
 var transport = flag.String("transport", "buffered", "Transport: buffered, 
framed, http, zlib")
-var protocol = flag.String("protocol", "binary", "Protocol: binary, compact, 
json")
+var protocol = flag.String("protocol", "binary", "Protocol: binary, compact, 
json, header")
 var ssl = flag.Bool("ssl", false, "Encrypted Transport using SSL")
 var zlib = flag.Bool("zlib", false, "Wrapped Transport using Zlib")
 var certPath = flag.String("certPath", "keys", "Directory that contains SSL 
certificates")
@@ -43,7 +43,7 @@ func main() {
        processor, serverTransport, transportFactory, protocolFactory, err := 
common.GetServerParams(*host, *port, *domain_socket, *transport, *protocol, 
*ssl, *certPath, common.PrintingHandler)
 
        if err != nil {
-               log.Fatalf("Unable to process server params: ", err)
+               log.Fatalf("Unable to process server params: %v", err)
        }
 
        if *transport == "http" {
diff --git a/test/go/src/common/client.go b/test/go/src/common/client.go
index 236ce43..ed820ae 100644
--- a/test/go/src/common/client.go
+++ b/test/go/src/common/client.go
@@ -55,6 +55,8 @@ func StartClient(
                protocolFactory = thrift.NewTJSONProtocolFactory()
        case "binary":
                protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
+       case "header":
+               protocolFactory = thrift.NewTHeaderProtocolFactory()
        default:
                return nil, nil, fmt.Errorf("Invalid protocol specified %s", 
protocol)
        }
diff --git a/test/go/src/common/server.go b/test/go/src/common/server.go
index 5ac4400..c6674ae 100644
--- a/test/go/src/common/server.go
+++ b/test/go/src/common/server.go
@@ -60,6 +60,8 @@ func GetServerParams(
                protocolFactory = thrift.NewTJSONProtocolFactory()
        case "binary":
                protocolFactory = thrift.NewTBinaryProtocolFactoryDefault()
+       case "header":
+               protocolFactory = thrift.NewTHeaderProtocolFactory()
        default:
                return nil, nil, nil, nil, fmt.Errorf("Invalid protocol 
specified %s", protocol)
        }
diff --git a/test/known_failures_Linux.json b/test/known_failures_Linux.json
index 6c61d5a..9f97754 100644
--- a/test/known_failures_Linux.json
+++ b/test/known_failures_Linux.json
@@ -31,12 +31,16 @@
   "cpp-go_binary_http-ip-ssl",
   "cpp-go_compact_http-ip",
   "cpp-go_compact_http-ip-ssl",
+  "cpp-go_header_http-ip",
+  "cpp-go_header_http-ip-ssl",
   "cpp-go_json_http-ip",
   "cpp-go_json_http-ip-ssl",
   "cpp-go_multi-binary_http-ip",
   "cpp-go_multi-binary_http-ip-ssl",
   "cpp-go_multic-compact_http-ip",
   "cpp-go_multic-compact_http-ip-ssl",
+  "cpp-go_multih-header_http-ip",
+  "cpp-go_multih-header_http-ip-ssl",
   "cpp-go_multij-json_http-ip",
   "cpp-go_multij-json_http-ip-ssl",
   "cpp-java_binary_http-ip",
@@ -304,6 +308,8 @@
   "go-cpp_binary_http-ip-ssl",
   "go-cpp_compact_http-ip",
   "go-cpp_compact_http-ip-ssl",
+  "go-cpp_header_http-ip",
+  "go-cpp_header_http-ip-ssl",
   "go-cpp_json_http-ip",
   "go-cpp_json_http-ip-ssl",
   "go-d_binary_http-ip",
@@ -357,6 +363,8 @@
   "nodejs-go_binary_http-ip-ssl",
   "nodejs-go_compact_http-ip",
   "nodejs-go_compact_http-ip-ssl",
+  "nodejs-go_header_http-ip",
+  "nodejs-go_header_http-ip-ssl",
   "nodejs-go_json_http-ip",
   "nodejs-go_json_http-ip-ssl",
   "nodejs-hs_binary_http-ip",
diff --git a/test/tests.json b/test/tests.json
index a4680d1..851244e 100644
--- a/test/tests.json
+++ b/test/tests.json
@@ -115,7 +115,8 @@
     "protocols": [
       "binary",
       "compact",
-      "json"
+      "json",
+      "header"
     ],
     "workdir": "go/bin"
   },

Reply via email to