You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@thrift.apache.org by yu...@apache.org on 2022/01/05 22:22:06 UTC

[thrift] branch master updated: THRIFT-5490: Use pooled buffer for TFramedTransport

This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 999e6e3  THRIFT-5490: Use pooled buffer for TFramedTransport
999e6e3 is described below

commit 999e6e3bce217acb35b44440fd656cf169d47ed8
Author: Yuxuan 'fishy' Wang <yu...@reddit.com>
AuthorDate: Fri Dec 17 10:39:07 2021 -0800

    THRIFT-5490: Use pooled buffer for TFramedTransport
    
    Client: go
    
    Follow up on d582a8614, do the same thing on TFramedTransport.
    
    Also update the test on the implementation of THeaderTransport to make
    sure that small reads are not broken.
---
 lib/go/thrift/framed_transport.go      | 61 ++++++++++++++++++++--------
 lib/go/thrift/framed_transport_test.go | 73 ++++++++++++++++++++++++++++++++++
 lib/go/thrift/header_transport_test.go | 39 +++++++++++++++---
 3 files changed, 150 insertions(+), 23 deletions(-)

diff --git a/lib/go/thrift/framed_transport.go b/lib/go/thrift/framed_transport.go
index 2156dd7..c8bd35e 100644
--- a/lib/go/thrift/framed_transport.go
+++ b/lib/go/thrift/framed_transport.go
@@ -36,10 +36,10 @@ type TFramedTransport struct {
 
 	cfg *TConfiguration
 
-	writeBuf bytes.Buffer
+	writeBuf *bytes.Buffer
 
 	reader  *bufio.Reader
-	readBuf bytes.Buffer
+	readBuf *bytes.Buffer
 
 	buffer [4]byte
 }
@@ -129,18 +129,29 @@ func (p *TFramedTransport) Close() error {
 }
 
 func (p *TFramedTransport) Read(buf []byte) (read int, err error) {
-	read, err = p.readBuf.Read(buf)
-	if err != io.EOF {
-		return
-	}
+	defer func() {
+		// Make sure we return the read buffer back to pool
+		// after we finished reading from it.
+		if p.readBuf != nil && p.readBuf.Len() == 0 {
+			returnBufToPool(&p.readBuf)
+		}
+	}()
+
+	if p.readBuf != nil {
 
-	// For bytes.Buffer.Read, EOF would only happen when read is zero,
-	// but still, do a sanity check,
-	// in case that behavior is changed in a future version of go stdlib.
-	// When that happens, just return nil error,
-	// and let the caller call Read again to read the next frame.
-	if read > 0 {
-		return read, nil
+		read, err = p.readBuf.Read(buf)
+		if err != io.EOF {
+			return
+		}
+
+		// For bytes.Buffer.Read, EOF would only happen when read is zero,
+		// but still, do a sanity check,
+		// in case that behavior is changed in a future version of go stdlib.
+		// When that happens, just return nil error,
+		// and let the caller call Read again to read the next frame.
+		if read > 0 {
+			return read, nil
+		}
 	}
 
 	// Reaching here means that the last Read finished the last frame,
@@ -162,31 +173,39 @@ func (p *TFramedTransport) ReadByte() (c byte, err error) {
 	return
 }
 
+func (p *TFramedTransport) ensureWriteBufferBeforeWrite() {
+	if p.writeBuf == nil {
+		p.writeBuf = getBufFromPool()
+	}
+}
+
 func (p *TFramedTransport) Write(buf []byte) (int, error) {
+	p.ensureWriteBufferBeforeWrite()
 	n, err := p.writeBuf.Write(buf)
 	return n, NewTTransportExceptionFromError(err)
 }
 
 func (p *TFramedTransport) WriteByte(c byte) error {
+	p.ensureWriteBufferBeforeWrite()
 	return p.writeBuf.WriteByte(c)
 }
 
 func (p *TFramedTransport) WriteString(s string) (n int, err error) {
+	p.ensureWriteBufferBeforeWrite()
 	return p.writeBuf.WriteString(s)
 }
 
 func (p *TFramedTransport) Flush(ctx context.Context) error {
+	defer returnBufToPool(&p.writeBuf)
 	size := p.writeBuf.Len()
 	buf := p.buffer[:4]
 	binary.BigEndian.PutUint32(buf, uint32(size))
 	_, err := p.transport.Write(buf)
 	if err != nil {
-		p.writeBuf.Reset()
 		return NewTTransportExceptionFromError(err)
 	}
 	if size > 0 {
-		if _, err := io.Copy(p.transport, &p.writeBuf); err != nil {
-			p.writeBuf.Reset()
+		if _, err := io.Copy(p.transport, p.writeBuf); err != nil {
 			return NewTTransportExceptionFromError(err)
 		}
 	}
@@ -195,6 +214,11 @@ func (p *TFramedTransport) Flush(ctx context.Context) error {
 }
 
 func (p *TFramedTransport) readFrame() error {
+	if p.readBuf != nil {
+		returnBufToPool(&p.readBuf)
+	}
+	p.readBuf = getBufFromPool()
+
 	buf := p.buffer[:4]
 	if _, err := io.ReadFull(p.reader, buf); err != nil {
 		return err
@@ -203,11 +227,14 @@ func (p *TFramedTransport) readFrame() error {
 	if size > uint32(p.cfg.GetMaxFrameSize()) {
 		return NewTTransportException(UNKNOWN_TRANSPORT_EXCEPTION, fmt.Sprintf("Incorrect frame size (%d)", size))
 	}
-	_, err := io.CopyN(&p.readBuf, p.reader, int64(size))
+	_, err := io.CopyN(p.readBuf, p.reader, int64(size))
 	return NewTTransportExceptionFromError(err)
 }
 
 func (p *TFramedTransport) RemainingBytes() (num_bytes uint64) {
+	if p.readBuf == nil {
+		return 0
+	}
 	return uint64(p.readBuf.Len())
 }
 
diff --git a/lib/go/thrift/framed_transport_test.go b/lib/go/thrift/framed_transport_test.go
index 8f683ef..4e7d9ca 100644
--- a/lib/go/thrift/framed_transport_test.go
+++ b/lib/go/thrift/framed_transport_test.go
@@ -20,6 +20,9 @@
 package thrift
 
 import (
+	"context"
+	"io"
+	"strings"
 	"testing"
 )
 
@@ -27,3 +30,73 @@ func TestFramedTransport(t *testing.T) {
 	trans := NewTFramedTransport(NewTMemoryBuffer())
 	TransportTest(t, trans, trans)
 }
+
+func TestTFramedTransportReuseTransport(t *testing.T) {
+	const (
+		content = "Hello, world!"
+		n       = 10
+	)
+	trans := NewTMemoryBuffer()
+	reader := NewTFramedTransport(trans)
+	writer := NewTFramedTransport(trans)
+
+	t.Run("pair", func(t *testing.T) {
+		for i := 0; i < n; i++ {
+			// write
+			if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
+				t.Fatalf("Failed to write on #%d: %v", i, err)
+			}
+			if err := writer.Flush(context.Background()); err != nil {
+				t.Fatalf("Failed to flush on #%d: %v", i, err)
+			}
+
+			// read
+			read, err := io.ReadAll(oneAtATimeReader{reader})
+			if err != nil {
+				t.Errorf("Failed to read on #%d: %v", i, err)
+			}
+			if string(read) != content {
+				t.Errorf("Read #%d: want %q, got %q", i, content, read)
+			}
+		}
+	})
+
+	t.Run("batched", func(t *testing.T) {
+		// write
+		for i := 0; i < n; i++ {
+			if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
+				t.Fatalf("Failed to write on #%d: %v", i, err)
+			}
+			if err := writer.Flush(context.Background()); err != nil {
+				t.Fatalf("Failed to flush on #%d: %v", i, err)
+			}
+		}
+
+		// read
+		for i := 0; i < n; i++ {
+			const (
+				size = len(content)
+			)
+			var buf []byte
+			var err error
+			if i%2 == 0 {
+				// on even calls, use oneAtATimeReader to make
+				// sure that small reads are fine
+				buf, err = io.ReadAll(io.LimitReader(oneAtATimeReader{reader}, int64(size)))
+			} else {
+				// on odd calls, make sure that we don't read
+				// more than written per frame
+				buf = make([]byte, size*2)
+				var n int
+				n, err = reader.Read(buf)
+				buf = buf[:n]
+			}
+			if err != nil {
+				t.Errorf("Failed to read on #%d: %v", i, err)
+			}
+			if string(buf) != content {
+				t.Errorf("Read #%d: want %q, got %q", i, content, buf)
+			}
+		}
+	})
+}
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 25ba8d3..44d0284 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -325,7 +325,7 @@ func TestTHeaderTransportReuseTransport(t *testing.T) {
 			}
 
 			// read
-			read, err := io.ReadAll(reader)
+			read, err := io.ReadAll(oneAtATimeReader{reader})
 			if err != nil {
 				t.Errorf("Failed to read on #%d: %v", i, err)
 			}
@@ -348,15 +348,42 @@ func TestTHeaderTransportReuseTransport(t *testing.T) {
 
 		// read
 		for i := 0; i < n; i++ {
-			buf := make([]byte, len(content))
-			n, err := reader.Read(buf)
+			const (
+				size = len(content)
+			)
+			var buf []byte
+			var err error
+			if i%2 == 0 {
+				// on even calls, use oneAtATimeReader to make
+				// sure that small reads are fine
+				buf, err = io.ReadAll(io.LimitReader(oneAtATimeReader{reader}, int64(size)))
+			} else {
+				// on odd calls, make sure that we don't read
+				// more than written per frame
+				buf = make([]byte, size*2)
+				var n int
+				n, err = reader.Read(buf)
+				buf = buf[:n]
+			}
 			if err != nil {
 				t.Errorf("Failed to read on #%d: %v", i, err)
 			}
-			read := string(buf[:n])
-			if string(read) != content {
-				t.Errorf("Read #%d: want %q, got %q", i, content, read)
+			if string(buf) != content {
+				t.Errorf("Read #%d: want %q, got %q", i, content, buf)
 			}
 		}
 	})
 }
+
+type oneAtATimeReader struct {
+	io.Reader
+}
+
+// oneAtATimeReader forces every Read call to only read 1 byte out,
+// thus forces the underlying reader's Read to be called multiple times.
+func (o oneAtATimeReader) Read(buf []byte) (int, error) {
+	if len(buf) < 1 {
+		return o.Reader.Read(buf)
+	}
+	return o.Reader.Read(buf[:1])
+}