You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@thrift.apache.org by dc...@apache.org on 2020/01/18 20:56:06 UTC

[thrift] branch master updated: THRIFT-5069: Make TDeserializer resource pool friendly

This is an automated email from the ASF dual-hosted git repository.

dcelasun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new 397645a  THRIFT-5069: Make TDeserializer resource pool friendly
397645a is described below

commit 397645ac24874b6f54d88b2700e56be090753825
Author: Yuxuan 'fishy' Wang <yu...@reddit.com>
AuthorDate: Sat Jan 18 12:55:51 2020 -0800

    THRIFT-5069: Make TDeserializer resource pool friendly
    
    Client: go
    
    This change improves performance when using TDeserializer with a
    resource pool. See https://issues.apache.org/jira/browse/THRIFT-5069 for
    more context.
    
    Also add TSerializerPool and TDeserializerPool, which are thread-safe
    versions of TSerializer and TDeserializer. Benchmark results show that
    they are both faster and use less memory than the plain version:
    
        $ go test -bench Serializer -benchmem
        goos: darwin
        goarch: amd64
        BenchmarkSerializer/baseline-8            577558              1930 ns/op             512 B/op          6 allocs/op
        BenchmarkSerializer/plain-8               452712              2638 ns/op            2976 B/op         16 allocs/op
        BenchmarkSerializer/pool-8                591698              2032 ns/op             512 B/op          6 allocs/op
        PASS
---
 CHANGES.md                       |   6 +
 lib/go/thrift/deserializer.go    |  46 +++++-
 lib/go/thrift/serializer.go      |  34 ++++
 lib/go/thrift/serializer_test.go | 325 ++++++++++++++++++++++++++++++---------
 4 files changed, 332 insertions(+), 79 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index e179a63..1dddab9 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -7,10 +7,16 @@
 - [THRIFT-4990](https://issues.apache.org/jira/browse/THRIFT-4990) - Upgrade to .NET Core 3.1 (LTS)
 - [THRIFT-4981](https://issues.apache.org/jira/browse/THRIFT-4981) - Remove deprecated netcore bindings from the code base
 - [THRIFT-5006](https://issues.apache.org/jira/browse/THRIFT-5006) - Implement DEFAULT_MAX_LENGTH at TFramedTransport
+- [THRIFT-5069](https://issues.apache.org/jira/browse/THRIFT-5069) - In the Go library, TDeserializer.Transport is now typed \*TMemoryBuffer instead of TTransport
 
 ### Java
 
 - [THRIFT-5022](https://issues.apache.org/jira/browse/THRIFT-5022) - TIOStreamTransport.isOpen returns true for one-sided transports (see THRIFT-2530).
+
+### Go
+
+- [THRIFT-5069](https://issues.apache.org/jira/browse/THRIFT-5069) - Add TSerializerPool and TDeserializerPool, which are thread-safe versions of TSerializer and TDeserializer.
+
 ## 0.13.0
 
 ### New Languages
diff --git a/lib/go/thrift/deserializer.go b/lib/go/thrift/deserializer.go
index 91a0983..2ab8214 100644
--- a/lib/go/thrift/deserializer.go
+++ b/lib/go/thrift/deserializer.go
@@ -19,14 +19,17 @@
 
 package thrift
 
+import (
+	"sync"
+)
+
 type TDeserializer struct {
-	Transport TTransport
+	Transport *TMemoryBuffer
 	Protocol  TProtocol
 }
 
 func NewTDeserializer() *TDeserializer {
-	var transport TTransport
-	transport = NewTMemoryBufferLen(1024)
+	transport := NewTMemoryBufferLen(1024)
 
 	protocol := NewTBinaryProtocolFactoryDefault().GetProtocol(transport)
 
@@ -36,6 +39,8 @@ func NewTDeserializer() *TDeserializer {
 }
 
 func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) {
+	t.Transport.Reset()
+
 	err = nil
 	if _, err = t.Transport.Write([]byte(s)); err != nil {
 		return
@@ -47,6 +52,8 @@ func (t *TDeserializer) ReadString(msg TStruct, s string) (err error) {
 }
 
 func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) {
+	t.Transport.Reset()
+
 	err = nil
 	if _, err = t.Transport.Write(b); err != nil {
 		return
@@ -56,3 +63,36 @@ func (t *TDeserializer) Read(msg TStruct, b []byte) (err error) {
 	}
 	return
 }
+
+// TDeserializerPool is the thread-safe version of TDeserializer;
+// it uses a resource pool of TDeserializer under the hood.
+//
+// It must be initialized with NewTDeserializerPool.
+type TDeserializerPool struct {
+	pool sync.Pool
+}
+
+// NewTDeserializerPool creates a new TDeserializerPool.
+//
+// NewTDeserializer can be used as the arg here.
+func NewTDeserializerPool(f func() *TDeserializer) *TDeserializerPool {
+	return &TDeserializerPool{
+		pool: sync.Pool{
+			New: func() interface{} {
+				return f()
+			},
+		},
+	}
+}
+
+func (t *TDeserializerPool) ReadString(msg TStruct, s string) error {
+	d := t.pool.Get().(*TDeserializer)
+	defer t.pool.Put(d)
+	return d.ReadString(msg, s)
+}
+
+func (t *TDeserializerPool) Read(msg TStruct, b []byte) error {
+	d := t.pool.Get().(*TDeserializer)
+	defer t.pool.Put(d)
+	return d.Read(msg, b)
+}
diff --git a/lib/go/thrift/serializer.go b/lib/go/thrift/serializer.go
index 1ff4d37..d85d204 100644
--- a/lib/go/thrift/serializer.go
+++ b/lib/go/thrift/serializer.go
@@ -21,6 +21,7 @@ package thrift
 
 import (
 	"context"
+	"sync"
 )
 
 type TSerializer struct {
@@ -77,3 +78,36 @@ func (t *TSerializer) Write(ctx context.Context, msg TStruct) (b []byte, err err
 	b = append(b, t.Transport.Bytes()...)
 	return
 }
+
+// TSerializerPool is the thread-safe version of TSerializer; it uses a
+// resource pool of TSerializer under the hood.
+//
+// It must be initialized with NewTSerializerPool.
+type TSerializerPool struct {
+	pool sync.Pool
+}
+
+// NewTSerializerPool creates a new TSerializerPool.
+//
+// NewTSerializer can be used as the arg here.
+func NewTSerializerPool(f func() *TSerializer) *TSerializerPool {
+	return &TSerializerPool{
+		pool: sync.Pool{
+			New: func() interface{} {
+				return f()
+			},
+		},
+	}
+}
+
+func (t *TSerializerPool) WriteString(ctx context.Context, msg TStruct) (string, error) {
+	s := t.pool.Get().(*TSerializer)
+	defer t.pool.Put(s)
+	return s.WriteString(ctx, msg)
+}
+
+func (t *TSerializerPool) Write(ctx context.Context, msg TStruct) ([]byte, error) {
+	s := t.pool.Get().(*TSerializer)
+	defer t.pool.Put(s)
+	return s.Write(ctx, msg)
+}
diff --git a/lib/go/thrift/serializer_test.go b/lib/go/thrift/serializer_test.go
index 32227ef..52ebdca 100644
--- a/lib/go/thrift/serializer_test.go
+++ b/lib/go/thrift/serializer_test.go
@@ -23,122 +23,193 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"sync"
+	"sync/atomic"
 	"testing"
+	"testing/quick"
 )
 
 type ProtocolFactory interface {
 	GetProtocol(t TTransport) TProtocol
 }
 
-func compareStructs(m, m1 MyTestStruct) (bool, error) {
+func compareStructs(m, m1 MyTestStruct) error {
 	switch {
 	case m.On != m1.On:
-		return false, errors.New("Boolean not equal")
+		return errors.New("Boolean not equal")
 	case m.B != m1.B:
-		return false, errors.New("Byte not equal")
+		return errors.New("Byte not equal")
 	case m.Int16 != m1.Int16:
-		return false, errors.New("Int16 not equal")
+		return errors.New("Int16 not equal")
 	case m.Int32 != m1.Int32:
-		return false, errors.New("Int32 not equal")
+		return errors.New("Int32 not equal")
 	case m.Int64 != m1.Int64:
-		return false, errors.New("Int64 not equal")
+		return errors.New("Int64 not equal")
 	case m.D != m1.D:
-		return false, errors.New("Double not equal")
+		return errors.New("Double not equal")
 	case m.St != m1.St:
-		return false, errors.New("String not equal")
+		return errors.New("String not equal")
 
 	case len(m.Bin) != len(m1.Bin):
-		return false, errors.New("Binary size not equal")
+		return errors.New("Binary size not equal")
 	case len(m.Bin) == len(m1.Bin):
 		for i := range m.Bin {
 			if m.Bin[i] != m1.Bin[i] {
-				return false, errors.New("Binary not equal")
+				return errors.New("Binary not equal")
 			}
 		}
 	case len(m.StringMap) != len(m1.StringMap):
-		return false, errors.New("StringMap size not equal")
+		return errors.New("StringMap size not equal")
 	case len(m.StringList) != len(m1.StringList):
-		return false, errors.New("StringList size not equal")
+		return errors.New("StringList size not equal")
 	case len(m.StringSet) != len(m1.StringSet):
-		return false, errors.New("StringSet size not equal")
+		return errors.New("StringSet size not equal")
 
 	case m.E != m1.E:
-		return false, errors.New("MyTestEnum not equal")
+		return errors.New("MyTestEnum not equal")
 
 	default:
-		return true, nil
+		return nil
 
 	}
-	return true, nil
+	return nil
 }
 
-func ProtocolTest1(test *testing.T, pf ProtocolFactory) (bool, error) {
+type serializer interface {
+	WriteString(context.Context, TStruct) (string, error)
+}
+
+type deserializer interface {
+	ReadString(TStruct, string) error
+}
+
+func plainSerializer(pf ProtocolFactory) serializer {
 	t := NewTSerializer()
 	t.Protocol = pf.GetProtocol(t.Transport)
-	var m = MyTestStruct{}
-	m.On = true
-	m.B = int8(0)
-	m.Int16 = 1
-	m.Int32 = 2
-	m.Int64 = 3
-	m.D = 4.1
-	m.St = "Test"
-	m.Bin = make([]byte, 10)
-	m.StringMap = make(map[string]string, 5)
-	m.StringList = make([]string, 5)
-	m.StringSet = make(map[string]struct{}, 5)
-	m.E = 2
-
-	s, err := t.WriteString(context.Background(), &m)
-	if err != nil {
-		return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err))
-	}
+	return t
+}
 
-	t1 := NewTDeserializer()
-	t1.Protocol = pf.GetProtocol(t1.Transport)
-	var m1 = MyTestStruct{}
-	if err = t1.ReadString(&m1, s); err != nil {
-		return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err))
+func poolSerializer(pf ProtocolFactory) serializer {
+	return NewTSerializerPool(
+		func() *TSerializer {
+			return plainSerializer(pf).(*TSerializer)
+		},
+	)
+}
 
-	}
+func plainDeserializer(pf ProtocolFactory) deserializer {
+	d := NewTDeserializer()
+	d.Protocol = pf.GetProtocol(d.Transport)
+	return d
+}
 
-	return compareStructs(m, m1)
+func poolDeserializer(pf ProtocolFactory) deserializer {
+	return NewTDeserializerPool(
+		func() *TDeserializer {
+			return plainDeserializer(pf).(*TDeserializer)
+		},
+	)
+}
 
+type constructors struct {
+	Label        string
+	Serializer   func(pf ProtocolFactory) serializer
+	Deserializer func(pf ProtocolFactory) deserializer
 }
 
-func ProtocolTest2(test *testing.T, pf ProtocolFactory) (bool, error) {
-	t := NewTSerializer()
-	t.Protocol = pf.GetProtocol(t.Transport)
-	var m = MyTestStruct{}
-	m.On = false
-	m.B = int8(0)
-	m.Int16 = 1
-	m.Int32 = 2
-	m.Int64 = 3
-	m.D = 4.1
-	m.St = "Test"
-	m.Bin = make([]byte, 10)
-	m.StringMap = make(map[string]string, 5)
-	m.StringList = make([]string, 5)
-	m.StringSet = make(map[string]struct{}, 5)
-	m.E = 2
-
-	s, err := t.WriteString(context.Background(), &m)
-	if err != nil {
-		return false, errors.New(fmt.Sprintf("Unable to Serialize struct\n\t %s", err))
+var implementations = []constructors{
+	{
+		Label:        "plain",
+		Serializer:   plainSerializer,
+		Deserializer: plainDeserializer,
+	},
+	{
+		Label:        "pool",
+		Serializer:   poolSerializer,
+		Deserializer: poolDeserializer,
+	},
+}
 
-	}
+func ProtocolTest1(t *testing.T, pf ProtocolFactory) {
+	for _, impl := range implementations {
+		t.Run(
+			impl.Label,
+			func(test *testing.T) {
+				t := impl.Serializer(pf)
+				var m = MyTestStruct{}
+				m.On = true
+				m.B = int8(0)
+				m.Int16 = 1
+				m.Int32 = 2
+				m.Int64 = 3
+				m.D = 4.1
+				m.St = "Test"
+				m.Bin = make([]byte, 10)
+				m.StringMap = make(map[string]string, 5)
+				m.StringList = make([]string, 5)
+				m.StringSet = make(map[string]struct{}, 5)
+				m.E = 2
+
+				s, err := t.WriteString(context.Background(), &m)
+				if err != nil {
+					test.Fatalf("Unable to Serialize struct: %v", err)
+
+				}
+
+				t1 := impl.Deserializer(pf)
+				var m1 MyTestStruct
+				if err = t1.ReadString(&m1, s); err != nil {
+					test.Fatalf("Unable to Deserialize struct: %v", err)
 
-	t1 := NewTDeserializer()
-	t1.Protocol = pf.GetProtocol(t1.Transport)
-	var m1 = MyTestStruct{}
-	if err = t1.ReadString(&m1, s); err != nil {
-		return false, errors.New(fmt.Sprintf("Unable to Deserialize struct\n\t %s", err))
+				}
 
+				if err := compareStructs(m, m1); err != nil {
+					test.Error(err)
+				}
+			},
+		)
 	}
+}
+
+func ProtocolTest2(t *testing.T, pf ProtocolFactory) {
+	for _, impl := range implementations {
+		t.Run(
+			impl.Label,
+			func(test *testing.T) {
+				t := impl.Serializer(pf)
+				var m = MyTestStruct{}
+				m.On = false
+				m.B = int8(0)
+				m.Int16 = 1
+				m.Int32 = 2
+				m.Int64 = 3
+				m.D = 4.1
+				m.St = "Test"
+				m.Bin = make([]byte, 10)
+				m.StringMap = make(map[string]string, 5)
+				m.StringList = make([]string, 5)
+				m.StringSet = make(map[string]struct{}, 5)
+				m.E = 2
+
+				s, err := t.WriteString(context.Background(), &m)
+				if err != nil {
+					test.Fatalf("Unable to Serialize struct: %v", err)
+
+				}
 
-	return compareStructs(m, m1)
+				t1 := impl.Deserializer(pf)
+				var m1 MyTestStruct
+				if err = t1.ReadString(&m1, s); err != nil {
+					test.Fatalf("Unable to Deserialize struct: %v", err)
 
+				}
+
+				if err := compareStructs(m, m1); err != nil {
+					test.Error(err)
+				}
+			},
+		)
+	}
 }
 
 func TestSerializer(t *testing.T) {
@@ -150,21 +221,123 @@ func TestSerializer(t *testing.T) {
 	//protocol_factories["SimpleJSON"] = NewTSimpleJSONProtocolFactory() - write only, can't be read back by design
 	protocol_factories["JSON"] = NewTJSONProtocolFactory()
 
-	var tests map[string]func(*testing.T, ProtocolFactory) (bool, error)
-	tests = make(map[string]func(*testing.T, ProtocolFactory) (bool, error))
+	tests := make(map[string]func(*testing.T, ProtocolFactory))
 	tests["Test 1"] = ProtocolTest1
 	tests["Test 2"] = ProtocolTest2
 	//tests["Test 3"] = ProtocolTest3 // Example of how to add additional tests
 
 	for name, pf := range protocol_factories {
+		t.Run(
+			name,
+			func(t *testing.T) {
+				for label, f := range tests {
+					t.Run(
+						label,
+						func(t *testing.T) {
+							f(t, pf)
+						},
+					)
+				}
+			},
+		)
+	}
+
+}
 
-		for test, f := range tests {
+func TestSerializerPoolAsync(t *testing.T) {
+	var wg sync.WaitGroup
+	var counter int64
+	s := NewTSerializerPool(NewTSerializer)
+	d := NewTDeserializerPool(NewTDeserializer)
+	f := func(i int64) bool {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			t.Run(
+				fmt.Sprintf("#%d-%d", atomic.AddInt64(&counter, 1), i),
+				func(t *testing.T) {
+					m := MyTestStruct{
+						Int64: i,
+					}
+					str, err := s.WriteString(context.Background(), &m)
+					if err != nil {
+						t.Fatal("serialize:", err)
+					}
+					var m1 MyTestStruct
+					if err = d.ReadString(&m1, str); err != nil {
+						t.Fatal("deserialize:", err)
 
-			if s, err := f(t, pf); !s || err != nil {
-				t.Errorf("%s Failed for %s protocol\n\t %s", test, name, err)
-			}
+					}
 
-		}
+					if err := compareStructs(m, m1); err != nil {
+						t.Error(err)
+					}
+				},
+			)
+		}()
+		return true
+	}
+	quick.Check(f, nil)
+	wg.Wait()
+}
+
+func BenchmarkSerializer(b *testing.B) {
+	sharedSerializer := NewTSerializer()
+	poolSerializer := NewTSerializerPool(NewTSerializer)
+	sharedDeserializer := NewTDeserializer()
+	poolDeserializer := NewTDeserializerPool(NewTDeserializer)
+
+	cases := []struct {
+		Label        string
+		Serializer   func() serializer
+		Deserializer func() deserializer
+	}{
+		{
+			// Baseline uses shared plain serializer/deserializer
+			Label: "baseline",
+			Serializer: func() serializer {
+				return sharedSerializer
+			},
+			Deserializer: func() deserializer {
+				return sharedDeserializer
+			},
+		},
+		{
+			// Plain creates a new serializer/deserializer on every run,
+			// as that's how it's used in the real world
+			Label: "plain",
+			Serializer: func() serializer {
+				return NewTSerializer()
+			},
+			Deserializer: func() deserializer {
+				return NewTDeserializer()
+			},
+		},
+		{
+			// Pool uses the shared pool serializer/deserializer
+			Label: "pool",
+			Serializer: func() serializer {
+				return poolSerializer
+			},
+			Deserializer: func() deserializer {
+				return poolDeserializer
+			},
+		},
 	}
 
+	for _, c := range cases {
+		b.Run(
+			c.Label,
+			func(b *testing.B) {
+				for i := 0; i < b.N; i++ {
+					s := c.Serializer()
+					m := MyTestStruct{}
+					str, _ := s.WriteString(context.Background(), &m)
+					var m1 MyTestStruct
+					d := c.Deserializer()
+					d.ReadString(&m1, str)
+				}
+			},
+		)
+	}
 }