You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@thrift.apache.org by yu...@apache.org on 2021/12/17 18:24:24 UTC

[thrift] branch master updated: THRIFT-5490: Use pooled buffer for THeaderTransport

This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new d582a86  THRIFT-5490: Use pooled buffer for THeaderTransport
d582a86 is described below

commit d582a861426c43c869e71d8d6ce598a33cbab316
Author: Yuxuan 'fishy' Wang <yu...@reddit.com>
AuthorDate: Thu Dec 16 14:44:47 2021 -0800

    THRIFT-5490: Use pooled buffer for THeaderTransport
    
    Client: go
    
    Instead of binding two buffers (one for reading, one for writing) to
    each THeaderTransport, grab one from the pool for the duration of each
    read or write and return it to the pool once that read or write is
    done. This helps reduce the memory footprint of idle connections.
---
 lib/go/thrift/buf_pool.go              | 52 ++++++++++++++++++++++++++++++
 lib/go/thrift/header_transport.go      | 42 +++++++++++++-----------
 lib/go/thrift/header_transport_test.go | 59 ++++++++++++++++++++++++++++++++--
 3 files changed, 133 insertions(+), 20 deletions(-)

diff --git a/lib/go/thrift/buf_pool.go b/lib/go/thrift/buf_pool.go
new file mode 100644
index 0000000..9708ea0
--- /dev/null
+++ b/lib/go/thrift/buf_pool.go
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+	"bytes"
+	"sync"
+)
+
+var bufPool = sync.Pool{
+	New: func() interface{} {
+		return new(bytes.Buffer)
+	},
+}
+
+// getBufFromPool gets a buffer out of the pool and guarantees that it has been
+// reset before it is returned to the caller.
+func getBufFromPool() *bytes.Buffer {
+	buf := bufPool.Get().(*bytes.Buffer)
+	buf.Reset()
+	return buf
+}
+
+// returnBufToPool returns a buffer to the pool, and sets the caller's pointer
+// to nil to avoid accidental use of the buffer after it's returned.
+//
+// You usually want to use it this way:
+//
+//     buf := getBufFromPool()
+//     defer returnBufToPool(&buf)
+//     // use buf
+func returnBufToPool(buf **bytes.Buffer) {
+	bufPool.Put(*buf)
+	*buf = nil
+}
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index f5736df..5ec0454 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -28,7 +28,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"io/ioutil"
 )
 
 // Size in bytes for 32-bit ints.
@@ -253,14 +252,14 @@ type THeaderTransport struct {
 	// Reading related variables.
 	reader *bufio.Reader
 	// When frame is detected, we read the frame fully into frameBuffer.
-	frameBuffer bytes.Buffer
+	frameBuffer *bytes.Buffer
 	// When it's non-nil, Read should read from frameReader instead of
 	// reader, and EOF error indicates end of frame instead of end of all
 	// transport.
 	frameReader io.ReadCloser
 
 	// Writing related variables
-	writeBuffer     bytes.Buffer
+	writeBuffer     *bytes.Buffer
 	writeTransforms []THeaderTransformID
 
 	clientType clientType
@@ -370,11 +369,14 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error {
 	t.reader.Discard(size32)
 
 	// Read the frame fully into frameBuffer.
-	_, err = io.CopyN(&t.frameBuffer, t.reader, int64(frameSize))
+	if t.frameBuffer == nil {
+		t.frameBuffer = getBufFromPool()
+	}
+	_, err = io.CopyN(t.frameBuffer, t.reader, int64(frameSize))
 	if err != nil {
 		return err
 	}
-	t.frameReader = ioutil.NopCloser(&t.frameBuffer)
+	t.frameReader = io.NopCloser(t.frameBuffer)
 
 	// Peek and handle the next 32 bits.
 	buf = t.frameBuffer.Bytes()[:size32]
@@ -405,7 +407,7 @@ func (t *THeaderTransport) ReadFrame(ctx context.Context) error {
 // It closes frameReader, and also resets frame related states.
 func (t *THeaderTransport) endOfFrame() error {
 	defer func() {
-		t.frameBuffer.Reset()
+		returnBufToPool(&t.frameBuffer)
 		t.frameReader = nil
 	}()
 	return t.frameReader.Close()
@@ -418,7 +420,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e
 
 	var err error
 	var meta headerMeta
-	if err = binary.Read(&t.frameBuffer, binary.BigEndian, &meta); err != nil {
+	if err = binary.Read(t.frameBuffer, binary.BigEndian, &meta); err != nil {
 		return err
 	}
 	frameSize -= headerMetaSize
@@ -432,7 +434,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e
 		)
 	}
 	headerBuf := NewTMemoryBuffer()
-	_, err = io.CopyN(headerBuf, &t.frameBuffer, headerLength)
+	_, err = io.CopyN(headerBuf, t.frameBuffer, headerLength)
 	if err != nil {
 		return err
 	}
@@ -454,7 +456,7 @@ func (t *THeaderTransport) parseHeaders(ctx context.Context, frameSize uint32) e
 	}
 	if transformCount > 0 {
 		reader := NewTransformReaderWithCapacity(
-			&t.frameBuffer,
+			t.frameBuffer,
 			int(transformCount),
 		)
 		t.frameReader = reader
@@ -569,16 +571,19 @@ func (t *THeaderTransport) Read(p []byte) (read int, err error) {
 //
 // You need to call Flush to actually write them to the transport.
 func (t *THeaderTransport) Write(p []byte) (int, error) {
+	if t.writeBuffer == nil {
+		t.writeBuffer = getBufFromPool()
+	}
 	return t.writeBuffer.Write(p)
 }
 
 // Flush writes the appropriate header and the write buffer to the underlying transport.
 func (t *THeaderTransport) Flush(ctx context.Context) error {
-	if t.writeBuffer.Len() == 0 {
+	if t.writeBuffer == nil || t.writeBuffer.Len() == 0 {
 		return nil
 	}
 
-	defer t.writeBuffer.Reset()
+	defer returnBufToPool(&t.writeBuffer)
 
 	switch t.clientType {
 	default:
@@ -628,24 +633,25 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
 			}
 		}
 
-		var payload bytes.Buffer
+		payload := getBufFromPool()
+		defer returnBufToPool(&payload)
 		meta := headerMeta{
 			MagicFlags:   THeaderHeaderMagic + t.Flags&THeaderFlagsMask,
 			SequenceID:   t.SequenceID,
 			HeaderLength: uint16(headers.Len() / 4),
 		}
-		if err := binary.Write(&payload, binary.BigEndian, meta); err != nil {
+		if err := binary.Write(payload, binary.BigEndian, meta); err != nil {
 			return NewTTransportExceptionFromError(err)
 		}
-		if _, err := io.Copy(&payload, headers); err != nil {
+		if _, err := io.Copy(payload, headers); err != nil {
 			return NewTTransportExceptionFromError(err)
 		}
 
-		writer, err := NewTransformWriter(&payload, t.writeTransforms)
+		writer, err := NewTransformWriter(payload, t.writeTransforms)
 		if err != nil {
 			return NewTTransportExceptionFromError(err)
 		}
-		if _, err := io.Copy(writer, &t.writeBuffer); err != nil {
+		if _, err := io.Copy(writer, t.writeBuffer); err != nil {
 			return NewTTransportExceptionFromError(err)
 		}
 		if err := writer.Close(); err != nil {
@@ -659,7 +665,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
 			return NewTTransportExceptionFromError(err)
 		}
 		// Then write the payload
-		if _, err := io.Copy(t.transport, &payload); err != nil {
+		if _, err := io.Copy(t.transport, payload); err != nil {
 			return NewTTransportExceptionFromError(err)
 		}
 
@@ -671,7 +677,7 @@ func (t *THeaderTransport) Flush(ctx context.Context) error {
 		}
 		fallthrough
 	case clientUnframedBinary, clientUnframedCompact:
-		if _, err := io.Copy(t.transport, &t.writeBuffer); err != nil {
+		if _, err := io.Copy(t.transport, t.writeBuffer); err != nil {
 			return NewTTransportExceptionFromError(err)
 		}
 	}
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 65e69ee..25ba8d3 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -23,7 +23,6 @@ import (
 	"context"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"strings"
 	"testing"
 	"testing/quick"
@@ -87,7 +86,7 @@ func testTHeaderHeadersReadWriteProtocolID(t *testing.T, protoID THeaderProtocol
 	if err := reader.ReadFrame(context.Background()); err != nil {
 		t.Errorf("reader.ReadFrame returned error: %v", err)
 	}
-	read, err := ioutil.ReadAll(reader)
+	read, err := io.ReadAll(reader)
 	if err != nil {
 		t.Errorf("Read returned error: %v", err)
 	}
@@ -305,3 +304,59 @@ func TestSetTHeaderTransportProtocolID(t *testing.T) {
 		t.Errorf("Expected protocol id %v, got %v", expected, actual)
 	}
 }
+
+func TestTHeaderTransportReuseTransport(t *testing.T) {
+	const (
+		content = "Hello, world!"
+		n       = 10
+	)
+	trans := NewTMemoryBuffer()
+	reader := NewTHeaderTransport(trans)
+	writer := NewTHeaderTransport(trans)
+
+	t.Run("pair", func(t *testing.T) {
+		for i := 0; i < n; i++ {
+			// write
+			if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
+				t.Fatalf("Failed to write on #%d: %v", i, err)
+			}
+			if err := writer.Flush(context.Background()); err != nil {
+				t.Fatalf("Failed to flush on #%d: %v", i, err)
+			}
+
+			// read
+			read, err := io.ReadAll(reader)
+			if err != nil {
+				t.Errorf("Failed to read on #%d: %v", i, err)
+			}
+			if string(read) != content {
+				t.Errorf("Read #%d: want %q, got %q", i, content, read)
+			}
+		}
+	})
+
+	t.Run("batched", func(t *testing.T) {
+		// write
+		for i := 0; i < n; i++ {
+			if _, err := io.Copy(writer, strings.NewReader(content)); err != nil {
+				t.Fatalf("Failed to write on #%d: %v", i, err)
+			}
+			if err := writer.Flush(context.Background()); err != nil {
+				t.Fatalf("Failed to flush on #%d: %v", i, err)
+			}
+		}
+
+		// read
+		for i := 0; i < n; i++ {
+			buf := make([]byte, len(content))
+			n, err := reader.Read(buf)
+			if err != nil {
+				t.Errorf("Failed to read on #%d: %v", i, err)
+			}
+			read := string(buf[:n])
+			if string(read) != content {
+				t.Errorf("Read #%d: want %q, got %q", i, content, read)
+			}
+		}
+	})
+}