You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@thrift.apache.org by yu...@apache.org on 2020/12/16 17:33:50 UTC

[thrift] branch master updated: THRIFT-5322: Guard against large string/binary lengths in Go

This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 37c2ceb  THRIFT-5322: Guard against large string/binary lengths in Go
37c2ceb is described below

commit 37c2ceb737cb40377346c63a05f407da1c119ba0
Author: Yuxuan 'fishy' Wang <yu...@reddit.com>
AuthorDate: Thu Dec 10 14:42:37 2020 -0800

    THRIFT-5322: Guard against large string/binary lengths in Go
    
    Client: go
    
    In TBinaryProtocol.ReadString, TBinaryProtocol.ReadBinary,
    TCompactProtocol.ReadString, and TCompactProtocol.ReadBinary, use
    safeReadBytes to prevent large allocations on malformed sizes.
    
        $ go test -bench=SafeReadBytes -benchmem
        BenchmarkSafeReadBytes/normal-12                  625057              1789 ns/op            2176 B/op          5 allocs/op
        BenchmarkSafeReadBytes/max-askedSize-12           545271              2236 ns/op           14464 B/op          7 allocs/op
        PASS
---
 lib/go/thrift/binary_protocol.go      |  56 +++++--------
 lib/go/thrift/binary_protocol_test.go | 153 ++++++++++++++++++++++++++++++++++
 lib/go/thrift/compact_protocol.go     |  17 ++--
 3 files changed, 184 insertions(+), 42 deletions(-)

diff --git a/lib/go/thrift/binary_protocol.go b/lib/go/thrift/binary_protocol.go
index c87d23a..58956f6 100644
--- a/lib/go/thrift/binary_protocol.go
+++ b/lib/go/thrift/binary_protocol.go
@@ -432,6 +432,15 @@ func (p *TBinaryProtocol) ReadString(ctx context.Context) (value string, err err
 		err = invalidDataLength
 		return
 	}
+	if size == 0 {
+		return "", nil
+	}
+	if size < int32(len(p.buffer)) {
+		// Avoid allocation on small reads
+		buf := p.buffer[:size]
+		read, e := io.ReadFull(p.trans, buf)
+		return string(buf[:read]), NewTProtocolException(e)
+	}
 
 	return p.readStringBody(size)
 }
@@ -445,9 +454,7 @@ func (p *TBinaryProtocol) ReadBinary(ctx context.Context) ([]byte, error) {
 		return nil, invalidDataLength
 	}
 
-	isize := int(size)
-	buf := make([]byte, isize)
-	_, err := io.ReadFull(p.trans, buf)
+	buf, err := safeReadBytes(size, p.trans)
 	return buf, NewTProtocolException(err)
 }
 
@@ -479,38 +486,21 @@ func (p *TBinaryProtocol) readAll(ctx context.Context, buf []byte) (err error) {
 	return NewTProtocolException(err)
 }
 
-const readLimit = 32768
-
 func (p *TBinaryProtocol) readStringBody(size int32) (value string, err error) {
-	if size < 0 {
-		return "", nil
-	}
-
-	var (
-		buf bytes.Buffer
-		e   error
-		b   []byte
-	)
+	buf, err := safeReadBytes(size, p.trans)
+	return string(buf), NewTProtocolException(err)
+}
 
-	switch {
-	case int(size) <= len(p.buffer):
-		b = p.buffer[:size] // avoids allocation for small reads
-	case int(size) < readLimit:
-		b = make([]byte, size)
-	default:
-		b = make([]byte, readLimit)
+// This function is shared between TBinaryProtocol and TCompactProtocol.
+//
+// It tries to read size bytes from trans, in a way that prevents large
+// allocations when size is insanely large (mostly caused by malformed messages).
+func safeReadBytes(size int32, trans io.Reader) ([]byte, error) {
+	if size < 0 {
+		return nil, nil
 	}
 
-	for size > 0 {
-		_, e = io.ReadFull(p.trans, b)
-		buf.Write(b)
-		if e != nil {
-			break
-		}
-		size -= readLimit
-		if size < readLimit && size > 0 {
-			b = b[:size]
-		}
-	}
-	return buf.String(), NewTProtocolException(e)
+	buf := new(bytes.Buffer)
+	_, err := io.CopyN(buf, trans, int64(size))
+	return buf.Bytes(), err
 }
diff --git a/lib/go/thrift/binary_protocol_test.go b/lib/go/thrift/binary_protocol_test.go
index 0462cc7..88bfd26 100644
--- a/lib/go/thrift/binary_protocol_test.go
+++ b/lib/go/thrift/binary_protocol_test.go
@@ -20,9 +20,162 @@
 package thrift
 
 import (
+	"bytes"
+	"math"
+	"strings"
 	"testing"
 )
 
 func TestReadWriteBinaryProtocol(t *testing.T) {
 	ReadWriteProtocolTest(t, NewTBinaryProtocolFactoryDefault())
 }
+
+const (
+	safeReadBytesSource = `
+Lorem ipsum dolor sit amet, consectetur adipiscing elit. Integer sit amet
+tincidunt nibh. Phasellus vel convallis libero, sit amet posuere quam. Nullam
+blandit velit at nibh fringilla, sed egestas erat dapibus. Sed hendrerit
+tincidunt accumsan. Curabitur consectetur bibendum dui nec hendrerit. Fusce quis
+turpis nec magna efficitur volutpat a ut nibh. Vestibulum odio risus, tristique
+a nisi et, congue mattis mi. Vivamus a nunc justo. Mauris molestie sagittis
+magna, hendrerit auctor lectus egestas non. Phasellus pretium, odio sit amet
+bibendum feugiat, velit nunc luctus erat, ac bibendum mi dui molestie nulla.
+Nullam fermentum magna eu elit vehicula tincidunt. Etiam ornare laoreet
+dignissim. Ut sed nunc ac neque vulputate fermentum. Morbi volutpat dapibus
+magna, at porttitor quam facilisis a. Donec eget fermentum risus. Aliquam erat
+volutpat.
+
+Phasellus molestie id ante vel iaculis. Fusce eget quam nec quam viverra laoreet
+vitae a dui. Mauris blandit blandit dui, iaculis interdum diam mollis at. Morbi
+vel sem et.
+`
+	safeReadBytesSourceLen = len(safeReadBytesSource)
+)
+
+func TestSafeReadBytes(t *testing.T) {
+	srcData := []byte(safeReadBytesSource)
+
+	for _, c := range []struct {
+		label     string
+		askedSize int32
+		dataSize  int
+	}{
+		{
+			label:     "normal",
+			askedSize: 100,
+			dataSize:  100,
+		},
+		{
+			label:     "max-askedSize",
+			askedSize: math.MaxInt32,
+			dataSize:  safeReadBytesSourceLen,
+		},
+	} {
+		t.Run(c.label, func(t *testing.T) {
+			data := bytes.NewReader(srcData[:c.dataSize])
+			buf, err := safeReadBytes(c.askedSize, data)
+			if len(buf) != c.dataSize {
+				t.Errorf(
+					"Expected to read %d bytes, got %d",
+					c.dataSize,
+					len(buf),
+				)
+			}
+			if !strings.HasPrefix(safeReadBytesSource, string(buf)) {
+				t.Errorf("Unexpected read data: %q", buf)
+			}
+			if int32(c.dataSize) < c.askedSize {
+				// We expect error in this case
+				if err == nil {
+					t.Errorf(
+						"Expected error when dataSize %d < askedSize %d, got nil",
+						c.dataSize,
+						c.askedSize,
+					)
+				}
+			} else {
+				// We expect no error in this case
+				if err != nil {
+					t.Errorf(
+						"Expected no error when dataSize %d >= askedSize %d, got: %v",
+						c.dataSize,
+						c.askedSize,
+						err,
+					)
+				}
+			}
+		})
+	}
+}
+
+func generateSafeReadBytesBenchmark(askedSize int32, dataSize int) func(b *testing.B) {
+	return func(b *testing.B) {
+		data := make([]byte, dataSize)
+		b.ResetTimer()
+		for i := 0; i < b.N; i++ {
+			safeReadBytes(askedSize, bytes.NewReader(data))
+		}
+	}
+}
+
+func TestSafeReadBytesAlloc(t *testing.T) {
+	if testing.Short() {
+		// NOTE: Since this test runs a benchmark test, it takes at
+		// least 1 second.
+		//
+		// In general we try to avoid unit tests taking that long to run,
+		// but it's to verify a security issue so we made an exception
+		// here:
+		// https://issues.apache.org/jira/browse/THRIFT-5322
+		t.Skip("skipping test in short mode.")
+	}
+
+	const (
+		askedSize = int32(math.MaxInt32)
+		dataSize  = 4096
+	)
+
+	// The purpose of this test is that in the case a string header says
+	// that it has a string askedSize bytes long, the implementation should
+	// not just allocate askedSize bytes upfront. So when there is actually
+	// not enough data to be read (dataSize), the actual allocated bytes
+	// should be somewhere between dataSize and askedSize.
+	//
+	// Different approaches could have different memory overheads, so this
+	// target is arbitrary in nature. But when dataSize is small enough
+	// compared to askedSize, half the askedSize is a good and safe target.
+	const target = int64(askedSize) / 2
+
+	bm := testing.Benchmark(generateSafeReadBytesBenchmark(askedSize, dataSize))
+	actual := bm.AllocedBytesPerOp()
+	if actual > target {
+		t.Errorf(
+			"Expected allocated bytes per op to be <= %d, got %d",
+			target,
+			actual,
+		)
+	} else {
+		t.Logf("Allocated bytes: %d B/op", actual)
+	}
+}
+
+func BenchmarkSafeReadBytes(b *testing.B) {
+	for _, c := range []struct {
+		label     string
+		askedSize int32
+		dataSize  int
+	}{
+		{
+			label:     "normal",
+			askedSize: 100,
+			dataSize:  100,
+		},
+		{
+			label:     "max-askedSize",
+			askedSize: math.MaxInt32,
+			dataSize:  4096,
+		},
+	} {
+		b.Run(c.label, generateSafeReadBytesBenchmark(c.askedSize, c.dataSize))
+	}
+}
diff --git a/lib/go/thrift/compact_protocol.go b/lib/go/thrift/compact_protocol.go
index a016195..424906d 100644
--- a/lib/go/thrift/compact_protocol.go
+++ b/lib/go/thrift/compact_protocol.go
@@ -579,17 +579,17 @@ func (p *TCompactProtocol) ReadString(ctx context.Context) (value string, err er
 	if length < 0 {
 		return "", invalidDataLength
 	}
-
 	if length == 0 {
 		return "", nil
 	}
-	var buf []byte
-	if length <= int32(len(p.buffer)) {
-		buf = p.buffer[0:length]
-	} else {
-		buf = make([]byte, length)
+	if length < int32(len(p.buffer)) {
+		// Avoid allocation on small reads
+		buf := p.buffer[:length]
+		read, e := io.ReadFull(p.trans, buf)
+		return string(buf[:read]), NewTProtocolException(e)
 	}
-	_, e = io.ReadFull(p.trans, buf)
+
+	buf, e := safeReadBytes(length, p.trans)
 	return string(buf), NewTProtocolException(e)
 }
 
@@ -606,8 +606,7 @@ func (p *TCompactProtocol) ReadBinary(ctx context.Context) (value []byte, err er
 		return nil, invalidDataLength
 	}
 
-	buf := make([]byte, length)
-	_, e = io.ReadFull(p.trans, buf)
+	buf, e := safeReadBytes(length, p.trans)
 	return buf, NewTProtocolException(e)
 }