You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@thrift.apache.org by yu...@apache.org on 2020/10/14 17:14:18 UTC

[thrift] branch master updated: THRIFT-5294: Fix panic in go TSimpleJSONProtocol

This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 64c2a4b  THRIFT-5294: Fix panic in go TSimpleJSONProtocol
64c2a4b is described below

commit 64c2a4b87ab356e05033045492e51f1ad73a795b
Author: Yuxuan 'fishy' Wang <yu...@reddit.com>
AuthorDate: Sat Oct 10 18:39:32 2020 -0700

    THRIFT-5294: Fix panic in go TSimpleJSONProtocol
    
    Client: go
    
    In the go library's TSimpleJSONProtocol and TJSONProtocol implementations,
    we use slices as stacks for context info, but didn't do proper boundary
    checks when peeking/popping. As a result, calling Write*End without a
    matching Write*Begin before it could, in certain cases, panic from using
    -1 as a slice index.
    
    Refactor the code to properly implement the stack, and return a
    TProtocolException instead in those cases.
    
    Also add unit tests for all protocols. The unit tests showed that
    TCompactProtocol.[Read|Write]StructEnd would also panic when called
    without a matching Begin call, so fix them as well.
---
 lib/go/thrift/compact_protocol.go          |   7 ++
 lib/go/thrift/json_protocol.go             |   4 +-
 lib/go/thrift/json_protocol_test.go        |   4 +
 lib/go/thrift/protocol_test.go             |  89 ++++++++++++++++
 lib/go/thrift/simple_json_protocol.go      | 163 ++++++++++++++++++-----------
 lib/go/thrift/simple_json_protocol_test.go |  55 ++++++++++
 6 files changed, 261 insertions(+), 61 deletions(-)

diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go
index 8510f1f..a016195 100644
--- a/lib/go/thrift/compact_protocol.go
+++ b/lib/go/thrift/compact_protocol.go
@@ -22,6 +22,7 @@ package thrift
 import (
 	"context"
 	"encoding/binary"
+	"errors"
 	"fmt"
 	"io"
 	"math"
@@ -158,6 +159,9 @@ func (p *TCompactProtocol) WriteStructBegin(ctx context.Context, name string) er
 // this as an opportunity to pop the last field from the current struct off
 // of the field stack.
 func (p *TCompactProtocol) WriteStructEnd(ctx context.Context) error {
+	if len(p.lastField) <= 0 {
+		return NewTProtocolExceptionWithType(INVALID_DATA, errors.New("WriteStructEnd called without matching WriteStructBegin call before"))
+	}
 	p.lastFieldId = p.lastField[len(p.lastField)-1]
 	p.lastField = p.lastField[:len(p.lastField)-1]
 	return nil
@@ -386,6 +390,9 @@ func (p *TCompactProtocol) ReadStructBegin(ctx context.Context) (name string, er
 // this struct from the field stack.
 func (p *TCompactProtocol) ReadStructEnd(ctx context.Context) error {
 	// consume the last field we read off the wire.
+	if len(p.lastField) <= 0 {
+		return NewTProtocolExceptionWithType(INVALID_DATA, errors.New("ReadStructEnd called without matching ReadStructBegin call before"))
+	}
 	p.lastFieldId = p.lastField[len(p.lastField)-1]
 	p.lastField = p.lastField[:len(p.lastField)-1]
 	return nil
diff --git a/lib/go/thrift/json_protocol.go b/lib/go/thrift/json_protocol.go
index 9a9328d..edc49cc 100644
--- a/lib/go/thrift/json_protocol.go
+++ b/lib/go/thrift/json_protocol.go
@@ -41,8 +41,8 @@ type TJSONProtocol struct {
 // Constructor
 func NewTJSONProtocol(t TTransport) *TJSONProtocol {
 	v := &TJSONProtocol{TSimpleJSONProtocol: NewTSimpleJSONProtocol(t)}
-	v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL))
-	v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
+	v.parseContextStack.push(_CONTEXT_IN_TOPLEVEL)
+	v.dumpContext.push(_CONTEXT_IN_TOPLEVEL)
 	return v
 }
 
diff --git a/lib/go/thrift/json_protocol_test.go b/lib/go/thrift/json_protocol_test.go
index 333d383..39e52d1 100644
--- a/lib/go/thrift/json_protocol_test.go
+++ b/lib/go/thrift/json_protocol_test.go
@@ -648,3 +648,7 @@ func TestWriteJSONProtocolMap(t *testing.T) {
 	}
 	trans.Close()
 }
+
+func TestTJSONProtocolUnmatchedBeginEnd(t *testing.T) {
+	UnmatchedBeginEndProtocolTest(t, NewTJSONProtocolFactory())
+}
diff --git a/lib/go/thrift/protocol_test.go b/lib/go/thrift/protocol_test.go
index c1c67e8..caac78e 100644
--- a/lib/go/thrift/protocol_test.go
+++ b/lib/go/thrift/protocol_test.go
@@ -217,6 +217,10 @@ func ReadWriteProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
 		ReadWriteByte(t, p, trans)
 		trans.Close()
 	}
+
+	t.Run("UnmatchedBeginEnd", func(t *testing.T) {
+		UnmatchedBeginEndProtocolTest(t, protocolFactory)
+	})
 }
 
 func ReadWriteBool(t testing.TB, p TProtocol, trans TTransport) {
@@ -515,3 +519,88 @@ func ReadWriteBinary(t testing.TB, p TProtocol, trans TTransport) {
 		}
 	}
 }
+
+func UnmatchedBeginEndProtocolTest(t *testing.T, protocolFactory TProtocolFactory) {
+	// NOTE: not all protocol implementations do strict state check to
+	// return an error on unmatched Begin/End calls.
+	// This test is only meant to make sure that those unmatched Begin/End
+	// calls won't cause panic. There's no real "test" here.
+	trans := NewTMemoryBuffer()
+	t.Run("Read", func(t *testing.T) {
+		t.Run("Message", func(t *testing.T) {
+			trans.Reset()
+			p := protocolFactory.GetProtocol(trans)
+			p.ReadMessageEnd(context.Background())
+			p.ReadMessageEnd(context.Background())
+		})
+		t.Run("Struct", func(t *testing.T) {
+			trans.Reset()
+			p := protocolFactory.GetProtocol(trans)
+			p.ReadStructEnd(context.Background())
+			p.ReadStructEnd(context.Background())
+		})
+		t.Run("Field", func(t *testing.T) {
+			trans.Reset()
+			p := protocolFactory.GetProtocol(trans)
+			p.ReadFieldEnd(context.Background())
+			p.ReadFieldEnd(context.Background())
+		})
+		t.Run("Map", func(t *testing.T) {
+			trans.Reset()
+			p := protocolFactory.GetProtocol(trans)
+			p.ReadMapEnd(context.Background())
+			p.ReadMapEnd(context.Background())
+		})
+		t.Run("List", func(t *testing.T) {
+			trans.Reset()
+			p := protocolFactory.GetProtocol(trans)
+			p.ReadListEnd(context.Background())
+			p.ReadListEnd(context.Background())
+		})
+		t.Run("Set", func(t *testing.T) {
+			trans.Reset()
+			p := protocolFactory.GetProtocol(trans)
+			p.ReadSetEnd(context.Background())
+			p.ReadSetEnd(context.Background())
+		})
+	})
+	t.Run("Write", func(t *testing.T) {
+		t.Run("Message", func(t *testing.T) {
+			trans.Reset()
+			p := protocolFactory.GetProtocol(trans)
+			p.WriteMessageEnd(context.Background())
+			p.WriteMessageEnd(context.Background())
+		})
+		t.Run("Struct", func(t *testing.T) {
+			trans.Reset()
+			p := protocolFactory.GetProtocol(trans)
+			p.WriteStructEnd(context.Background())
+			p.WriteStructEnd(context.Background())
+		})
+		t.Run("Field", func(t *testing.T) {
+			trans.Reset()
+			p := protocolFactory.GetProtocol(trans)
+			p.WriteFieldEnd(context.Background())
+			p.WriteFieldEnd(context.Background())
+		})
+		t.Run("Map", func(t *testing.T) {
+			trans.Reset()
+			p := protocolFactory.GetProtocol(trans)
+			p.WriteMapEnd(context.Background())
+			p.WriteMapEnd(context.Background())
+		})
+		t.Run("List", func(t *testing.T) {
+			trans.Reset()
+			p := protocolFactory.GetProtocol(trans)
+			p.WriteListEnd(context.Background())
+			p.WriteListEnd(context.Background())
+		})
+		t.Run("Set", func(t *testing.T) {
+			trans.Reset()
+			p := protocolFactory.GetProtocol(trans)
+			p.WriteSetEnd(context.Background())
+			p.WriteSetEnd(context.Background())
+		})
+	})
+	trans.Close()
+}
diff --git a/lib/go/thrift/simple_json_protocol.go b/lib/go/thrift/simple_json_protocol.go
index d101b99..e94b44b 100644
--- a/lib/go/thrift/simple_json_protocol.go
+++ b/lib/go/thrift/simple_json_protocol.go
@@ -25,6 +25,7 @@ import (
 	"context"
 	"encoding/base64"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"math"
@@ -34,12 +35,13 @@ import (
 type _ParseContext int
 
 const (
-	_CONTEXT_IN_TOPLEVEL          _ParseContext = 1
-	_CONTEXT_IN_LIST_FIRST        _ParseContext = 2
-	_CONTEXT_IN_LIST              _ParseContext = 3
-	_CONTEXT_IN_OBJECT_FIRST      _ParseContext = 4
-	_CONTEXT_IN_OBJECT_NEXT_KEY   _ParseContext = 5
-	_CONTEXT_IN_OBJECT_NEXT_VALUE _ParseContext = 6
+	_CONTEXT_INVALID              _ParseContext = iota
+	_CONTEXT_IN_TOPLEVEL                        // 1
+	_CONTEXT_IN_LIST_FIRST                      // 2
+	_CONTEXT_IN_LIST                            // 3
+	_CONTEXT_IN_OBJECT_FIRST                    // 4
+	_CONTEXT_IN_OBJECT_NEXT_KEY                 // 5
+	_CONTEXT_IN_OBJECT_NEXT_VALUE               // 6
 )
 
 func (p _ParseContext) String() string {
@@ -60,6 +62,32 @@ func (p _ParseContext) String() string {
 	return "UNKNOWN-PARSE-CONTEXT"
 }
 
+type jsonContextStack []_ParseContext
+
+func (s *jsonContextStack) push(v _ParseContext) {
+	*s = append(*s, v)
+}
+
+func (s jsonContextStack) peek() (v _ParseContext, ok bool) {
+	l := len(s)
+	if l <= 0 {
+		return
+	}
+	return s[l-1], true
+}
+
+func (s *jsonContextStack) pop() (v _ParseContext, ok bool) {
+	l := len(*s)
+	if l <= 0 {
+		return
+	}
+	v = (*s)[l-1]
+	*s = (*s)[0 : l-1]
+	return v, true
+}
+
+var errEmptyJSONContextStack = NewTProtocolExceptionWithType(INVALID_DATA, errors.New("Unexpected empty json protocol context stack"))
+
 // Simple JSON protocol implementation for thrift.
 //
 // This protocol produces/consumes a simple output format
@@ -69,8 +97,8 @@ func (p _ParseContext) String() string {
 type TSimpleJSONProtocol struct {
 	trans TTransport
 
-	parseContextStack []int
-	dumpContext       []int
+	parseContextStack jsonContextStack
+	dumpContext       jsonContextStack
 
 	writer *bufio.Writer
 	reader *bufio.Reader
@@ -82,8 +110,8 @@ func NewTSimpleJSONProtocol(t TTransport) *TSimpleJSONProtocol {
 		writer: bufio.NewWriter(t),
 		reader: bufio.NewReader(t),
 	}
-	v.parseContextStack = append(v.parseContextStack, int(_CONTEXT_IN_TOPLEVEL))
-	v.dumpContext = append(v.dumpContext, int(_CONTEXT_IN_TOPLEVEL))
+	v.parseContextStack.push(_CONTEXT_IN_TOPLEVEL)
+	v.dumpContext.push(_CONTEXT_IN_TOPLEVEL)
 	return v
 }
 
@@ -549,41 +577,41 @@ func (p *TSimpleJSONProtocol) Transport() TTransport {
 }
 
 func (p *TSimpleJSONProtocol) OutputPreValue() error {
-	cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1])
+	cxt, ok := p.dumpContext.peek()
+	if !ok {
+		return errEmptyJSONContextStack
+	}
 	switch cxt {
 	case _CONTEXT_IN_LIST, _CONTEXT_IN_OBJECT_NEXT_KEY:
 		if _, e := p.write(JSON_COMMA); e != nil {
 			return NewTProtocolException(e)
 		}
-		break
 	case _CONTEXT_IN_OBJECT_NEXT_VALUE:
 		if _, e := p.write(JSON_COLON); e != nil {
 			return NewTProtocolException(e)
 		}
-		break
 	}
 	return nil
 }
 
 func (p *TSimpleJSONProtocol) OutputPostValue() error {
-	cxt := _ParseContext(p.dumpContext[len(p.dumpContext)-1])
+	cxt, ok := p.dumpContext.peek()
+	if !ok {
+		return errEmptyJSONContextStack
+	}
 	switch cxt {
 	case _CONTEXT_IN_LIST_FIRST:
-		p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
-		p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST))
-		break
+		p.dumpContext.pop()
+		p.dumpContext.push(_CONTEXT_IN_LIST)
 	case _CONTEXT_IN_OBJECT_FIRST:
-		p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
-		p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
-		break
+		p.dumpContext.pop()
+		p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
 	case _CONTEXT_IN_OBJECT_NEXT_KEY:
-		p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
-		p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
-		break
+		p.dumpContext.pop()
+		p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
 	case _CONTEXT_IN_OBJECT_NEXT_VALUE:
-		p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
-		p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_NEXT_KEY))
-		break
+		p.dumpContext.pop()
+		p.dumpContext.push(_CONTEXT_IN_OBJECT_NEXT_KEY)
 	}
 	return nil
 }
@@ -598,10 +626,13 @@ func (p *TSimpleJSONProtocol) OutputBool(value bool) error {
 	} else {
 		v = string(JSON_FALSE)
 	}
-	switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
+	cxt, ok := p.dumpContext.peek()
+	if !ok {
+		return errEmptyJSONContextStack
+	}
+	switch cxt {
 	case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
 		v = jsonQuote(v)
-	default:
 	}
 	if e := p.OutputStringData(v); e != nil {
 		return e
@@ -631,11 +662,14 @@ func (p *TSimpleJSONProtocol) OutputF64(value float64) error {
 	} else if math.IsInf(value, -1) {
 		v = string(JSON_QUOTE) + JSON_NEGATIVE_INFINITY + string(JSON_QUOTE)
 	} else {
+		cxt, ok := p.dumpContext.peek()
+		if !ok {
+			return errEmptyJSONContextStack
+		}
 		v = strconv.FormatFloat(value, 'g', -1, 64)
-		switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
+		switch cxt {
 		case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
 			v = string(JSON_QUOTE) + v + string(JSON_QUOTE)
-		default:
 		}
 	}
 	if e := p.OutputStringData(v); e != nil {
@@ -648,11 +682,14 @@ func (p *TSimpleJSONProtocol) OutputI64(value int64) error {
 	if e := p.OutputPreValue(); e != nil {
 		return e
 	}
+	cxt, ok := p.dumpContext.peek()
+	if !ok {
+		return errEmptyJSONContextStack
+	}
 	v := strconv.FormatInt(value, 10)
-	switch _ParseContext(p.dumpContext[len(p.dumpContext)-1]) {
+	switch cxt {
 	case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
 		v = jsonQuote(v)
-	default:
 	}
 	if e := p.OutputStringData(v); e != nil {
 		return e
@@ -682,7 +719,7 @@ func (p *TSimpleJSONProtocol) OutputObjectBegin() error {
 	if _, e := p.write(JSON_LBRACE); e != nil {
 		return NewTProtocolException(e)
 	}
-	p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_OBJECT_FIRST))
+	p.dumpContext.push(_CONTEXT_IN_OBJECT_FIRST)
 	return nil
 }
 
@@ -690,7 +727,10 @@ func (p *TSimpleJSONProtocol) OutputObjectEnd() error {
 	if _, e := p.write(JSON_RBRACE); e != nil {
 		return NewTProtocolException(e)
 	}
-	p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
+	_, ok := p.dumpContext.pop()
+	if !ok {
+		return errEmptyJSONContextStack
+	}
 	if e := p.OutputPostValue(); e != nil {
 		return e
 	}
@@ -704,7 +744,7 @@ func (p *TSimpleJSONProtocol) OutputListBegin() error {
 	if _, e := p.write(JSON_LBRACKET); e != nil {
 		return NewTProtocolException(e)
 	}
-	p.dumpContext = append(p.dumpContext, int(_CONTEXT_IN_LIST_FIRST))
+	p.dumpContext.push(_CONTEXT_IN_LIST_FIRST)
 	return nil
 }
 
@@ -712,7 +752,10 @@ func (p *TSimpleJSONProtocol) OutputListEnd() error {
 	if _, e := p.write(JSON_RBRACKET); e != nil {
 		return NewTProtocolException(e)
 	}
-	p.dumpContext = p.dumpContext[:len(p.dumpContext)-1]
+	_, ok := p.dumpContext.pop()
+	if !ok {
+		return errEmptyJSONContextStack
+	}
 	if e := p.OutputPostValue(); e != nil {
 		return e
 	}
@@ -736,7 +779,10 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
 	if e := p.readNonSignificantWhitespace(); e != nil {
 		return NewTProtocolException(e)
 	}
-	cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+	cxt, ok := p.parseContextStack.peek()
+	if !ok {
+		return errEmptyJSONContextStack
+	}
 	b, _ := p.reader.Peek(1)
 	switch cxt {
 	case _CONTEXT_IN_LIST:
@@ -755,7 +801,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
 				return NewTProtocolExceptionWithType(INVALID_DATA, e)
 			}
 		}
-		break
 	case _CONTEXT_IN_OBJECT_NEXT_KEY:
 		if len(b) > 0 {
 			switch b[0] {
@@ -772,7 +817,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
 				return NewTProtocolExceptionWithType(INVALID_DATA, e)
 			}
 		}
-		break
 	case _CONTEXT_IN_OBJECT_NEXT_VALUE:
 		if len(b) > 0 {
 			switch b[0] {
@@ -787,7 +831,6 @@ func (p *TSimpleJSONProtocol) ParsePreValue() error {
 				return NewTProtocolExceptionWithType(INVALID_DATA, e)
 			}
 		}
-		break
 	}
 	return nil
 }
@@ -796,20 +839,20 @@ func (p *TSimpleJSONProtocol) ParsePostValue() error {
 	if e := p.readNonSignificantWhitespace(); e != nil {
 		return NewTProtocolException(e)
 	}
-	cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+	cxt, ok := p.parseContextStack.peek()
+	if !ok {
+		return errEmptyJSONContextStack
+	}
 	switch cxt {
 	case _CONTEXT_IN_LIST_FIRST:
-		p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
-		p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST))
-		break
+		p.parseContextStack.pop()
+		p.parseContextStack.push(_CONTEXT_IN_LIST)
 	case _CONTEXT_IN_OBJECT_FIRST, _CONTEXT_IN_OBJECT_NEXT_KEY:
-		p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
-		p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_VALUE))
-		break
+		p.parseContextStack.pop()
+		p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_VALUE)
 	case _CONTEXT_IN_OBJECT_NEXT_VALUE:
-		p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
-		p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_NEXT_KEY))
-		break
+		p.parseContextStack.pop()
+		p.parseContextStack.push(_CONTEXT_IN_OBJECT_NEXT_KEY)
 	}
 	return nil
 }
@@ -962,7 +1005,7 @@ func (p *TSimpleJSONProtocol) ParseObjectStart() (bool, error) {
 	}
 	if len(b) > 0 && b[0] == JSON_LBRACE[0] {
 		p.reader.ReadByte()
-		p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_OBJECT_FIRST))
+		p.parseContextStack.push(_CONTEXT_IN_OBJECT_FIRST)
 		return false, nil
 	} else if p.safePeekContains(JSON_NULL) {
 		return true, nil
@@ -975,7 +1018,7 @@ func (p *TSimpleJSONProtocol) ParseObjectEnd() error {
 	if isNull, err := p.readIfNull(); isNull || err != nil {
 		return err
 	}
-	cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+	cxt, _ := p.parseContextStack.peek()
 	if (cxt != _CONTEXT_IN_OBJECT_FIRST) && (cxt != _CONTEXT_IN_OBJECT_NEXT_KEY) {
 		e := fmt.Errorf("Expected to be in the Object Context, but not in Object Context (%d)", cxt)
 		return NewTProtocolExceptionWithType(INVALID_DATA, e)
@@ -993,7 +1036,7 @@ func (p *TSimpleJSONProtocol) ParseObjectEnd() error {
 			break
 		}
 	}
-	p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
+	p.parseContextStack.pop()
 	return p.ParsePostValue()
 }
 
@@ -1007,7 +1050,7 @@ func (p *TSimpleJSONProtocol) ParseListBegin() (isNull bool, err error) {
 		return false, err
 	}
 	if len(b) >= 1 && b[0] == JSON_LBRACKET[0] {
-		p.parseContextStack = append(p.parseContextStack, int(_CONTEXT_IN_LIST_FIRST))
+		p.parseContextStack.push(_CONTEXT_IN_LIST_FIRST)
 		p.reader.ReadByte()
 		isNull = false
 	} else if p.safePeekContains(JSON_NULL) {
@@ -1036,7 +1079,7 @@ func (p *TSimpleJSONProtocol) ParseListEnd() error {
 	if isNull, err := p.readIfNull(); isNull || err != nil {
 		return err
 	}
-	cxt := _ParseContext(p.parseContextStack[len(p.parseContextStack)-1])
+	cxt, _ := p.parseContextStack.peek()
 	if cxt != _CONTEXT_IN_LIST {
 		e := fmt.Errorf("Expected to be in the List Context, but not in List Context (%d)", cxt)
 		return NewTProtocolExceptionWithType(INVALID_DATA, e)
@@ -1054,8 +1097,10 @@ func (p *TSimpleJSONProtocol) ParseListEnd() error {
 			break
 		}
 	}
-	p.parseContextStack = p.parseContextStack[:len(p.parseContextStack)-1]
-	if _ParseContext(p.parseContextStack[len(p.parseContextStack)-1]) == _CONTEXT_IN_TOPLEVEL {
+	p.parseContextStack.pop()
+	if cxt, ok := p.parseContextStack.peek(); !ok {
+		return errEmptyJSONContextStack
+	} else if cxt == _CONTEXT_IN_TOPLEVEL {
 		return nil
 	}
 	return p.ParsePostValue()
@@ -1308,8 +1353,8 @@ func (p *TSimpleJSONProtocol) safePeekContains(b []byte) bool {
 
 // Reset the context stack to its initial state.
 func (p *TSimpleJSONProtocol) resetContextStack() {
-	p.parseContextStack = []int{int(_CONTEXT_IN_TOPLEVEL)}
-	p.dumpContext = []int{int(_CONTEXT_IN_TOPLEVEL)}
+	p.parseContextStack = jsonContextStack{_CONTEXT_IN_TOPLEVEL}
+	p.dumpContext = jsonContextStack{_CONTEXT_IN_TOPLEVEL}
 }
 
 func (p *TSimpleJSONProtocol) write(b []byte) (int, error) {
diff --git a/lib/go/thrift/simple_json_protocol_test.go b/lib/go/thrift/simple_json_protocol_test.go
index 986fff2..89753c6 100644
--- a/lib/go/thrift/simple_json_protocol_test.go
+++ b/lib/go/thrift/simple_json_protocol_test.go
@@ -736,3 +736,58 @@ func TestWriteSimpleJSONProtocolSafePeek(t *testing.T) {
 		t.Fatalf("Should not match at test 3")
 	}
 }
+
+func TestJSONContextStack(t *testing.T) {
+	var stack jsonContextStack
+	t.Run("empty-peek", func(t *testing.T) {
+		v, ok := stack.peek()
+		if ok {
+			t.Error("peek() on empty should return ok: false")
+		}
+		expected := _CONTEXT_INVALID
+		if v != expected {
+			t.Errorf("Expected value from peek() to be %v(%d), got %v(%d)", expected, expected, v, v)
+		}
+	})
+	t.Run("empty-pop", func(t *testing.T) {
+		v, ok := stack.pop()
+		if ok {
+			t.Error("pop() on empty should return ok: false")
+		}
+		expected := _CONTEXT_INVALID
+		if v != expected {
+			t.Errorf("Expected value from pop() to be %v(%d), got %v(%d)", expected, expected, v, v)
+		}
+	})
+	t.Run("push-peek-pop", func(t *testing.T) {
+		expected := _CONTEXT_INVALID
+		stack.push(expected)
+		if len(stack) != 1 {
+			t.Errorf("Expected stack to be as size 1 after push, got %#v", stack)
+		}
+		v, ok := stack.peek()
+		if !ok {
+			t.Error("peek() on non-empty should return ok: true")
+		}
+		if v != expected {
+			t.Errorf("Expected value from peek() to be %v(%d), got %v(%d)", expected, expected, v, v)
+		}
+		if len(stack) != 1 {
+			t.Errorf("Expected peek() to be read-only, got %#v", stack)
+		}
+		v, ok = stack.pop()
+		if !ok {
+			t.Error("pop() on non-empty should return ok: true")
+		}
+		if v != expected {
+			t.Errorf("Expected value from pop() to be %v(%d), got %v(%d)", expected, expected, v, v)
+		}
+		if len(stack) != 0 {
+			t.Errorf("Expected pop() to empty the stack, got %#v", stack)
+		}
+	})
+}
+
+func TestTSimpleJSONProtocolUnmatchedBeginEnd(t *testing.T) {
+	UnmatchedBeginEndProtocolTest(t, NewTSimpleJSONProtocolFactory())
+}