This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 64c2a4b  THRIFT-5294: Fix panic in go TSimpleJSONProtocol
64c2a4b is described below

commit 64c2a4b87ab356e05033045492e51f1ad73a795b
Author: Yuxuan 'fishy' Wang <[email protected]>
AuthorDate: Sat Oct 10 18:39:32 2020 -0700

    THRIFT-5294: Fix panic in go TSimpleJSONProtocol
    
    Client: go
    
    In the Go library's TSimpleJSONProtocol and TJSONProtocol
    implementations, we use slices as stacks for context info, but didn't
    do proper boundary checks when peeking/popping. As a result, they could
    panic by using -1 as a slice index in certain cases where Write*End
    is called without a matching Write*Begin call before it.
    
    Refactor the code to properly implement the stack, and return a
    TProtocolException instead in those cases.
    
    Also add unit tests for all protocols. The unit tests showed that
    TCompactProtocol.[Read|Write]StructEnd would also panic without
    matching Begin calls, so fix them as well.
---
 lib/go/thrift/compact_protocol.go          |   7 ++
 lib/go/thrift/json_protocol.go             |   4 +-
 lib/go/thrift/json_protocol_test.go        |   4 +
 lib/go/thrift/protocol_test.go             |  89 ++++++++++++++++
 lib/go/thrift/simple_json_protocol.go      | 163 ++++++++++++++++++-----------
 lib/go/thrift/simple_json_protocol_test.go |  55 ++++++++++
 6 files changed, 261 insertions(+), 61 deletions(-)

diff --git a/lib/go/thrift/compact_protocol.go 
b/lib/go/thrift/compact_protocol.go
index 8510f1f..a016195 100644
--- a/lib/go/thrift/compact_protocol.go
+++ b/lib/go/thrift/compact_protocol.go
@@ -22,6 +22,7 @@ package thrift
 import (
        "context"
        "encoding/binary"
+       "errors"
        "fmt"
        "io"
        "math"
@@ -158,6 +159,9 @@ func (p *TCompactProtocol) WriteStructBegin(ctx 
context.Context, name string) er
 // this as an opportunity to pop the last field from the current struct off
 // of the field stack.
 func (p *TCompactProtocol) WriteStructEnd(ctx context.Context) error {
+       if len(p.lastField) <= 0 {
+               return NewTProtocolExceptionWithType(INVALID_DATA, 
errors.New("WriteStructEnd called without matching WriteStructBegin call 
before"))
+       }
        p.lastFieldId = p.lastField[len(p.lastField)-1]
        p.lastField = p.lastField[:len(p.lastField)-1]
        return nil
@@ -386,6 +390,9 @@ func (p *TCompactProtocol) ReadStructBegin(ctx 
context.Context) (name string, er
 // this struct from the field stack.
 func (p *TCompactProtocol) ReadStructEnd(ctx context.Context) error {
        // consume the last field we read off the wire.
+       if len(p.lastField) <= 0 {
+               return NewTProtocolExceptionWithType(INVALID_DATA, 
errors.New("ReadStructEnd called without matching ReadStructBegin call before"))
+       }
        p.lastFieldId = p.lastField[len(p.lastField)-1]
        p.lastField = p.lastField[:len(p.lastField)-1]
        return nil
diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go
index 9a9328d..edc49cc 100644
--- a/lib/go/thrift/json_protocol.go
+++ b/lib/go/thrift/json_protocol.go
@@ -41,8 +41,8 @@ type TJSONProtocol struct {
 // Constructor
 func NewTJSONProtocol(t TTransport) *TJSONProtocol {
        v := &TJSONProtocol{TSimpleJSONProtocol: NewTSimpleJSONProtocol(t)}
-       v.parseContextStack = append(v.parseContextStack, 
int(_CONTEXT_IN_TOPLEVEL))
-       v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
+       v.parseContextStack.push(_CONTEXT_IN_TOPLEVEL)
+       v.dumpContext.push(_CONTEXT_IN_TOPLEVEL)
        return v
 }
 
diff --git a/lib/go/thrift/json_protocol_test.go 
b/lib/go/thrift/json_protocol_test.go
index 333d383..39e52d1 100644
--- a/lib/go/thrift/json_protocol_test.go
+++ b/lib/go/thrift/json_protocol_test.go
@@ -648,3 +648,7 @@ func TestWriteJSONProtocolMap(t *testing.T) {
        }
        trans.Close()
 }
+
+func TestTJSONProtocolUnmatchedBeginEnd(t *testing.T) {
+       UnmatchedBeginEndProtocolTest(t, NewTJSONProtocolFactory())
+}
diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go
index c1c67e8..caac78e 100644
--- a/lib/go/thrift/protocol_test.go
+++ b/lib/go/thrift/protocol_test.go
@@ -217,6 +217,10 @@ func ReadWriteProtocolTest(t *testing.T, protocolFactory 
TProtocolFactory) {
                ReadWriteByte(t, p, trans)
                trans.Close()
        }
+
+       t.Run("UnmatchedBeginEnd", func(t *testing.T) {
+               UnmatchedBeginEndProtocolTest(t, protocolFactory)
+       })
 }
 
 func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) {
@@ -515,3 +519,88 @@ func ReadWriteBinary(t testing.TB, p TProtocol, trans 
TTransport) {
                }
        }
 }
+
+func UnmatchedBeginEndProtocolTest(t *testing.T, protocolFactory 
TProtocolFactory) {
+       // NOTE: not all protocol implementations do strict state check to
+       // return an error on unmatched Begin/End calls.
+       // This test is only meant to make sure that those unmatched Begin/End
+       // calls won't cause panic. There's no real "test" here.
+       trans := NewTMemoryBuffer()
+       t.Run("Read", func(t *testing.T) {
+               t.Run("Message", func(t *testing.T) {
+                       trans.Reset()
+                       p := protocolFactory.GetProtocol(trans)
+                       p.ReadMessageEnd(context.Background())
+                       p.ReadMessageEnd(context.Background())
+               })
+               t.Run("Struct", func(t *testing.T) {
+                       trans.Reset()
+                       p := protocolFactory.GetProtocol(trans)
+                       p.ReadStructEnd(context.Background())
+                       p.ReadStructEnd(context.Background())
+               })
+               t.Run("Field", func(t *testing.T) {
+                       trans.Reset()
+                       p := protocolFactory.GetProtocol(trans)
+                       p.ReadFieldEnd(context.Background())
+                       p.ReadFieldEnd(context.Background())
+               })
+               t.Run("Map", func(t *testing.T) {
+                       trans.Reset()
+                       p := protocolFactory.GetProtocol(trans)
+                       p.ReadMapEnd(context.Background())
+                       p.ReadMapEnd(context.Background())
+               })
+               t.Run("List", func(t *testing.T) {
+                       trans.Reset()
+                       p := protocolFactory.GetProtocol(trans)
+                       p.ReadListEnd(context.Background())
+                       p.ReadListEnd(context.Background())
+               })
+               t.Run("Set", func(t *testing.T) {
+                       trans.Reset()
+                       p := protocolFactory.GetProtocol(trans)
+                       p.ReadSetEnd(context.Background())
+                       p.ReadSetEnd(context.Background())
+               })
+       })
+       t.Run("Write", func(t *testing.T) {
+               t.Run("Message", func(t *testing.T) {
+                       trans.Reset()
+                       p := protocolFactory.GetProtocol(trans)
+                       p.WriteMessageEnd(context.Background())
+                       p.WriteMessageEnd(context.Background())
+               })
+               t.Run("Struct", func(t *testing.T) {
+                       trans.Reset()
+                       p := protocolFactory.GetProtocol(trans)
+                       p.WriteStructEnd(context.Background())
+                       p.WriteStructEnd(context.Background())
+               })
+               t.Run("Field", func(t *testing.T) {
+                       trans.Reset()
+                       p := protocolFactory.GetProtocol(trans)
+                       p.WriteFieldEnd(context.Background())
+                       p.WriteFieldEnd(context.Background())
+               })
+               t.Run("Map", func(t *testing.T) {
+                       trans.Reset()
+                       p := protocolFactory.GetProtocol(trans)
+                       p.WriteMapEnd(context.Background())
+                       p.WriteMapEnd(context.Background())
+               })
+               t.Run("List", func(t *testing.T) {
+                       trans.Reset()
+                       p := protocolFactory.GetProtocol(trans)
+                       p.WriteListEnd(context.Background())
+                       p.WriteListEnd(context.Background())
+               })
+               t.Run("Set", func(t *testing.T) {
+                       trans.Reset()
+                       p := protocolFactory.GetProtocol(trans)
+                       p.WriteSetEnd(context.Background())
+                       p.WriteSetEnd(context.Background())
+               })
+       })
+       trans.Close()
+}
diff --git a/lib/go/thrift/simple_json_protocol.go 
b/lib/go/thrift/simple_json_protocol.go
index d101b99..e94b44b 100644
--- a/lib/go/thrift/simple_json_protocol.go
+++ b/lib/go/thrift/simple_json_protocol.go
@@ -25,6 +25,7 @@ import (
        "context"
        "encoding/base64"
        "encoding/json"
+       "errors"
        "fmt"
        "io"
        "math"
@@ -34,12 +35,13 @@ import (
 type _ParseContext int
 
 const (
-       _CONTEXT_IN_TOPLEVEL          _ParseContext = 1
-       _CONTEXT_IN_LIST_FIRST        _ParseContext = 2
-       _CONTEXT_IN_LIST              _ParseContext = 3
-       _CONTEXT_IN_OBJECT_FIRST      _ParseContext = 4
-       _CONTEXT_IN_OBJECT_NEXT_KEY   _ParseContext = 5
-       _CONTEXT_IN_OBJECT_NEXT_VALUE _ParseContext = 6
+       _CONTEXT_INVALID              _ParseContext = iota
+       _CONTEXT_IN_TOPLEVEL                        // 1
+       _CONTEXT_IN_LIST_FIRST                      // 2
+       _CONTEXT_IN_LIST                            // 3
+       _CONTEXT_IN_OBJECT_FIRST                    // 4
+       _CONTEXT_IN_OBJECT_NEXT_KEY                 // 5
+       _CONTEXT_IN_OBJECT_NEXT_VALUE               // 6
 )
 
 func (p _ParseContext) String() string {
@@ -60,6 +62,32 @@ func (p _ParseContext) String() string {
        return "UNKNOWN-PARSE-CONTEXT"
 }
 
+type jsonContextStack []_ParseContext
+
+func (s *jsonContextStack) push(v _ParseContext) {
+       *s = append(*s, v)
+}
+
+func (s jsonContextStack) peek() (v _ParseContext, ok bool) {
+       l := len(s)
+       if l <= 0 {
+               return
+       }
+       return s[l-1], true
+}
+
+func (s *jsonContextStack) pop() (v _ParseContext, ok bool) {
+       l := len(*s)
+       if l <= 0 {
+               return
+       }
+       v = (*s)[l-1]
+       *s = (*s)[0 : l-1]
+       return v, true
+}
+
+var errEmptyJSONContextStack = NewTProtocolExceptionWithType(INVALID_DATA, 
errors.New("Unexpected empty json protocol context stack"))
+
 // Simple JSON protocol implementation for thrift.
 //
 // This protocol produces/consumes a simple output format
@@ -69,8 +97,8 @@ func (p _ParseContext) String() string {
 type TSimpleJSONProtocol struct {
        trans TTransport
 
-       parseContextStack []int
-       dumpContext       []int
+       parseContextStack jsonContextStack
+       dumpContext       jsonContextStack
 
        writer *bufio.Writer
        reader *bufio.Reader
@@ -82,8 +110,8 @@ func NewTSimpleJSONProtocol(t TTransport) 
*TSimpleJSONProtocol {
                writer: bufio.NewWriter(t),
                reader: bufio.NewReader(t),
        }
-       v.parseContextStack = append(v.parseContextStack, 
int(_CONTEXT_IN_TOPLEVEL))
-       v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
+       v.parseContextStack.push(_CONTEXT_IN_TOPLEVEL)
+       v.dumpContext.push(_CONTEXT_IN_TOPLEVEL)
        return v
 }
 
@@ -549,41 +577,41 @@ func (p *TSimpleJSONProtocol) Transport() TTransport {
 }
 
 func (p *TSimpleJSONProtocol) OutputPreValue() error {
-       cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1])
+       cxt, ok := p.dumpContext.peek()
+       if !ok {
+               return errEmptyJSONContextStack
+       }
        switch cxt {
        case _CONTEXT_IN_LIST, _CONTEXT_IN_OBJECT_NEXT_KEY:
                if _, e := p.write(JSON_COMMA); e != nil {
                        return NewTProtocolException(e)
                }
-               break
        case _CONTEXT_IN_OBJECT_NEXT_VALUE:
                if _, e := p.write(JSON_COLON); e != nil {
                        return NewTProtocolException(e)
                }
-               break
        }
        return nil
 }
 
 func (p *TSimpleJSONProtocol) OutputPostValue() error {
-       cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1])
+       cxt, ok := p.dumpContext.peek()
+       if !ok {
+               return errEmptyJSONContextStack
+       }
        switch cxt {
        case _CONTEXT_IN_LIST_FIRST:
-               p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
-               p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST))
-               break
+               p.dumpContext.pop()
+               p.dumpContext.push(_CONTEXT_IN_LIST)
        case _CONTEXT_IN_OBJECT_FIRST:
-               p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
-               p.dumpContext = append(p.dumpContext, 
int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
-               break
+               p.dumpContext.pop()
+               p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
        case _CONTEXT_IN_OBJECT_NEXT_KEY:
-               p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
-               p.dumpContext = append(p.dumpContext, 
int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
-               break
+               p.dumpContext.pop()
+               p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
        case _CONTEXT_IN_OBJECT_NEXT_VALUE:
-               p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
-               p.dumpContext = append(p.dumpContext, 
int(_CONTEXT_IN_OBJECT_NEXT_KEY))
-               break
+               p.dumpContext.pop()
+               p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_KEY)
        }
        return nil
 }
@@ -598,10 +626,13 @@ func (p *TSimpleJSONProtocol) OutputBool(value bool) 
error {
        } else {
                v = string(JSON_FALSE)
        }
-       switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
+       cxt, ok := p.dumpContext.peek()
+       if !ok {
+               return errEmptyJSONContextStack
+       }
+       switch cxt {
        case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
                v = jsonQuote(v)
-       default:
        }
        if e := p.OutputStringData(v); e != nil {
                return e
@@ -631,11 +662,14 @@ func (p *TSimpleJSONProtocol) OutputF64(value float64) 
error {
        } else if math.IsInf(value, -1) {
                v = string(JSON_QUOTE) + JSON_NEGATIVE_INFINITY + 
string(JSON_QUOTE)
        } else {
+               cxt, ok := p.dumpContext.peek()
+               if !ok {
+                       return errEmptyJSONContextStack
+               }
                v = strconv.FormatFloat(value, 'g', -1, 64)
-               switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
+               switch cxt {
                case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
                        v = string(JSON_QUOTE) + v + string(JSON_QUOTE)
-               default:
                }
        }
        if e := p.OutputStringData(v); e != nil {
@@ -648,11 +682,14 @@ func (p *TSimpleJSONProtocol) OutputI64(value int64) 
error {
        if e := p.OutputPreValue(); e != nil {
                return e
        }
+       cxt, ok := p.dumpContext.peek()
+       if !ok {
+               return errEmptyJSONContextStack
+       }
        v := strconv.FormatInt(value, 10)
-       switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
+       switch cxt {
        case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
                v = jsonQuote(v)
-       default:
        }
        if e := p.OutputStringData(v); e != nil {
                return e
@@ -682,7 +719,7 @@ func (p *TSimpleJSONProtocol) OutputObjectBegin() error {
        if _, e := p.write(JSON_LBRACE); e != nil {
                return NewTProtocolException(e)
        }
-       p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_FIRST))
+       p.dumpContext.push(_CONTEXT_IN_OBJECT_FIRST)
        return nil
 }
 
@@ -690,7 +727,10 @@ func (p *TSimpleJSONProtocol) OutputObjectEnd() error {
        if _, e := p.write(JSON_RBRACE); e != nil {
                return NewTProtocolException(e)
        }
-       p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
+       _, ok := p.dumpContext.pop()
+       if !ok {
+               return errEmptyJSONContextStack
+       }
        if e := p.OutputPostValue(); e != nil {
                return e
        }
@@ -704,7 +744,7 @@ func (p *TSimpleJSONProtocol) OutputListBegin() error {
        if _, e := p.write(JSON_LBRACKET); e != nil {
                return NewTProtocolException(e)
        }
-       p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST_FIRST))
+       p.dumpContext.push(_CONTEXT_IN_LIST_FIRST)
        return nil
 }
 
@@ -712,7 +752,10 @@ func (p *TSimpleJSONProtocol) OutputListEnd() error {
        if _, e := p.write(JSON_RBRACKET); e != nil {
                return NewTProtocolException(e)
        }
-       p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
+       _, ok := p.dumpContext.pop()
+       if !ok {
+               return errEmptyJSONContextStack
+       }
        if e := p.OutputPostValue(); e != nil {
                return e
        }
@@ -736,7 +779,10 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
        if e := p.readNonSignificantWhitespace(); e != nil {
                return NewTProtocolException(e)
        }
-       cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+       cxt, ok := p.parseContextStack.peek()
+       if !ok {
+               return errEmptyJSONContextStack
+       }
        b, _ := p.reader.Peek(1)
        switch cxt {
        case _CONTEXT_IN_LIST:
@@ -755,7 +801,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
                                return 
NewTProtocolExceptionWithType(INVALID_DATA, e)
                        }
                }
-               break
        case _CONTEXT_IN_OBJECT_NEXT_KEY:
                if len(b) > 0 {
                        switch b[0] {
@@ -772,7 +817,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
                                return 
NewTProtocolExceptionWithType(INVALID_DATA, e)
                        }
                }
-               break
        case _CONTEXT_IN_OBJECT_NEXT_VALUE:
                if len(b) > 0 {
                        switch b[0] {
@@ -787,7 +831,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
                                return 
NewTProtocolExceptionWithType(INVALID_DATA, e)
                        }
                }
-               break
        }
        return nil
 }
@@ -796,20 +839,20 @@ func (p *TSimpleJSONProtocol) ParsePostValue() error {
        if e := p.readNonSignificantWhitespace(); e != nil {
                return NewTProtocolException(e)
        }
-       cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+       cxt, ok := p.parseContextStack.peek()
+       if !ok {
+               return errEmptyJSONContextStack
+       }
        switch cxt {
        case _CONTEXT_IN_LIST_FIRST:
-               p.parseContextStack = 
p.parseContextStack[:len(p.parseContextStack)-1]
-               p.parseContextStack = append(p.parseContextStack, 
int(_CONTEXT_IN_LIST))
-               break
+               p.parseContextStack.pop()
+               p.parseContextStack.push(_CONTEXT_IN_LIST)
        case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
-               p.parseContextStack = 
p.parseContextStack[:len(p.parseContextStack)-1]
-               p.parseContextStack = append(p.parseContextStack, 
int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
-               break
+               p.parseContextStack.pop()
+               p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
        case _CONTEXT_IN_OBJECT_NEXT_VALUE:
-               p.parseContextStack = 
p.parseContextStack[:len(p.parseContextStack)-1]
-               p.parseContextStack = append(p.parseContextStack, 
int(_CONTEXT_IN_OBJECT_NEXT_KEY))
-               break
+               p.parseContextStack.pop()
+               p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_KEY)
        }
        return nil
 }
@@ -962,7 +1005,7 @@ func (p *TSimpleJSONProtocol) ParseObjectStart() (bool, 
error) {
        }
        if len(b) > 0 && b[0] == JSON_LBRACE[0] {
                p.reader.ReadByte()
-               p.parseContextStack = append(p.parseContextStack, 
int(_CONTEXT_IN_OBJECT_FIRST))
+               p.parseContextStack.push(_CONTEXT_IN_OBJECT_FIRST)
                return false, nil
        } else if p.safePeekContains(JSON_NULL) {
                return true, nil
@@ -975,7 +1018,7 @@ func (p *TSimpleJSONProtocol) ParseObjectEnd() error {
        if isNull, err := p.readIfNull(); isNull || err != nil {
                return err
        }
-       cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+       cxt, _ := p.parseContextStack.peek()
        if (cxt != _CONTEXT_IN_OBJECT_FIRST) && (cxt != 
_CONTEXT_IN_OBJECT_NEXT_KEY) {
                e := fmt.Errorf("Expected to be in the Object Context, but not 
in Object Context (%d)", cxt)
                return NewTProtocolExceptionWithType(INVALID_DATA, e)
@@ -993,7 +1036,7 @@ func (p *TSimpleJSONProtocol) ParseObjectEnd() error {
                        break
                }
        }
-       p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
+       p.parseContextStack.pop()
        return p.ParsePostValue()
 }
 
@@ -1007,7 +1050,7 @@ func (p *TSimpleJSONProtocol) ParseListBegin() (isNull 
bool, err error) {
                return false, err
        }
        if len(b) >= 1 && b[0] == JSON_LBRACKET[0] {
-               p.parseContextStack = append(p.parseContextStack, 
int(_CONTEXT_IN_LIST_FIRST))
+               p.parseContextStack.push(_CONTEXT_IN_LIST_FIRST)
                p.reader.ReadByte()
                isNull = false
        } else if p.safePeekContains(JSON_NULL) {
@@ -1036,7 +1079,7 @@ func (p *TSimpleJSONProtocol) ParseListEnd() error {
        if isNull, err := p.readIfNull(); isNull || err != nil {
                return err
        }
-       cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+       cxt, _ := p.parseContextStack.peek()
        if cxt != _CONTEXT_IN_LIST {
                e := fmt.Errorf("Expected to be in the List Context, but not in 
List Context (%d)", cxt)
                return NewTProtocolExceptionWithType(INVALID_DATA, e)
@@ -1054,8 +1097,10 @@ func (p *TSimpleJSONProtocol) ParseListEnd() error {
                        break
                }
        }
-       p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
-       if _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) == 
_CONTEXT_IN_TOPLEVEL {
+       p.parseContextStack.pop()
+       if cxt, ok := p.parseContextStack.peek(); !ok {
+               return errEmptyJSONContextStack
+       } else if cxt == _CONTEXT_IN_TOPLEVEL {
                return nil
        }
        return p.ParsePostValue()
@@ -1308,8 +1353,8 @@ func (p *TSimpleJSONProtocol) safePeekContains(b []byte) 
bool {
 
 // Reset the context stack to its initial state.
 func (p *TSimpleJSONProtocol) resetContextStack() {
-       p.parseContextStack = []int{int(_CONTEXT_IN_TOPLEVEL)}
-       p.dumpContext = []int{int(_CONTEXT_IN_TOPLEVEL)}
+       p.parseContextStack = jsonContextStack{_CONTEXT_IN_TOPLEVEL}
+       p.dumpContext = jsonContextStack{_CONTEXT_IN_TOPLEVEL}
 }
 
 func (p *TSimpleJSONProtocol) write(b []byte) (int, error) {
diff --git a/lib/go/thrift/simple_json_protocol_test.go 
b/lib/go/thrift/simple_json_protocol_test.go
index 986fff2..89753c6 100644
--- a/lib/go/thrift/simple_json_protocol_test.go
+++ b/lib/go/thrift/simple_json_protocol_test.go
@@ -736,3 +736,58 @@ func TestWriteSimpleJSONProtocolSafePeek(t *testing.T) {
                t.Fatalf("Should not match at test 3")
        }
 }
+
+func TestJSONContextStack(t *testing.T) {
+       var stack jsonContextStack
+       t.Run("empty-peek", func(t *testing.T) {
+               v, ok := stack.peek()
+               if ok {
+                       t.Error("peek() on empty should return ok: false")
+               }
+               expected := _CONTEXT_INVALID
+               if v != expected {
+                       t.Errorf("Expected value from peek() to be %v(%d), got 
%v(%d)", expected, expected, v, v)
+               }
+       })
+       t.Run("empty-pop", func(t *testing.T) {
+               v, ok := stack.pop()
+               if ok {
+                       t.Error("pop() on empty should return ok: false")
+               }
+               expected := _CONTEXT_INVALID
+               if v != expected {
+                       t.Errorf("Expected value from pop() to be %v(%d), got 
%v(%d)", expected, expected, v, v)
+               }
+       })
+       t.Run("push-peek-pop", func(t *testing.T) {
+               expected := _CONTEXT_INVALID
+               stack.push(expected)
+               if len(stack) != 1 {
+                       t.Errorf("Expected stack to be as size 1 after push, 
got %#v", stack)
+               }
+               v, ok := stack.peek()
+               if !ok {
+                       t.Error("peek() on non-empty should return ok: true")
+               }
+               if v != expected {
+                       t.Errorf("Expected value from peek() to be %v(%d), got 
%v(%d)", expected, expected, v, v)
+               }
+               if len(stack) != 1 {
+                       t.Errorf("Expected peek() to be read-only, got %#v", 
stack)
+               }
+               v, ok = stack.pop()
+               if !ok {
+                       t.Error("pop() on non-empty should return ok: true")
+               }
+               if v != expected {
+                       t.Errorf("Expected value from pop() to be %v(%d), got 
%v(%d)", expected, expected, v, v)
+               }
+               if len(stack) != 0 {
+                       t.Errorf("Expected pop() to empty the stack, got %#v", 
stack)
+               }
+       })
+}
+
+func TestTSimpleJSONProtocolUnmatchedBeginEnd(t *testing.T) {
+       UnmatchedBeginEndProtocolTest(t, NewTSimpleJSONProtocolFactory())
+}

Reply via email to