You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@thrift.apache.org by yu...@apache.org on 2020/09/22 23:48:51 UTC

[thrift] branch master updated: THRIFT-5278: Allow set protoID in go THeader transport/protocol

This is an automated email from the ASF dual-hosted git repository.

yuxuan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/thrift.git


The following commit(s) were added to refs/heads/master by this push:
     new a2c4466  THRIFT-5278: Allow set protoID in go THeader transport/protocol
a2c4466 is described below

commit a2c44665b416522477cffa6752c2f323768d0507
Author: Yuxuan 'fishy' Wang <yu...@reddit.com>
AuthorDate: Mon Sep 21 12:33:26 2020 -0700

    THRIFT-5278: Allow set protoID in go THeader transport/protocol
    
    Client: go
    
    In Go library code, allow setting the underlying protoID to a
    non-default (TCompactProtocol) one for THeaderTransport/THeaderProtocol.
---
 lib/go/thrift/header_protocol.go       | 60 +++++++++++++++++++++++++++++-----
 lib/go/thrift/header_protocol_test.go  | 18 +++++++++-
 lib/go/thrift/header_transport.go      | 44 +++++++++++++++++++++++--
 lib/go/thrift/header_transport_test.go | 23 ++++++++++---
 4 files changed, 128 insertions(+), 17 deletions(-)

diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go
index 428b261..f86d558 100644
--- a/lib/go/thrift/header_protocol.go
+++ b/lib/go/thrift/header_protocol.go
@@ -37,31 +37,73 @@ type THeaderProtocol struct {
 }
 
 // NewTHeaderProtocol creates a new THeaderProtocol from the underlying
-// transport. The passed in transport will be wrapped with THeaderTransport.
+// transport with the default protocol ID (THeaderProtocolDefault).
+//
+// The passed in transport will be wrapped with THeaderTransport.
 //
 // Note that THeaderTransport handles frame and zlib by itself,
 // so the underlying transport should be a raw socket transports (TSocket or TSSLSocket),
 // instead of rich transports like TZlibTransport or TFramedTransport.
 func NewTHeaderProtocol(trans TTransport) *THeaderProtocol {
-	t := NewTHeaderTransport(trans)
-	p, _ := THeaderProtocolDefault.GetProtocol(t)
+	p, err := newTHeaderProtocolWithProtocolID(trans, THeaderProtocolDefault)
+	if err != nil {
+		// THeaderProtocolDefault is always supported, so this should never
+		// happen; panic as a sanity check just in case.
+		panic(err)
+	}
+	return p
+}
+
+func newTHeaderProtocolWithProtocolID(trans TTransport, protoID THeaderProtocolID) (*THeaderProtocol, error) {
+	t, err := NewTHeaderTransportWithProtocolID(trans, protoID)
+	if err != nil {
+		return nil, err
+	}
+	p, err := t.protocolID.GetProtocol(t)
+	if err != nil {
+		return nil, err
+	}
 	return &THeaderProtocol{
 		transport: t,
 		protocol:  p,
-	}
+	}, nil
 }
 
-type tHeaderProtocolFactory struct{}
+type tHeaderProtocolFactory struct {
+	protoID THeaderProtocolID
+}
 
-func (tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
-	return NewTHeaderProtocol(trans)
+func (f tHeaderProtocolFactory) GetProtocol(trans TTransport) TProtocol {
+	p, err := newTHeaderProtocolWithProtocolID(trans, f.protoID)
+	if err != nil {
+		// External users can only obtain this factory through the
+		// NewTHeaderProtocolFactory* constructors, which use a valid
+		// protoID, so this should never happen. Panic as a sanity check
+		// in case a future bug makes an invalid protoID possible.
+		panic(err)
+	}
+	return p
 }
 
-// NewTHeaderProtocolFactory creates a factory for THeader.
+// NewTHeaderProtocolFactory creates a factory for THeader with the default
+// protocol ID (THeaderProtocolDefault).
 //
 // It's a wrapper for NewTHeaderProtocol
 func NewTHeaderProtocolFactory() TProtocolFactory {
-	return tHeaderProtocolFactory{}
+	return tHeaderProtocolFactory{
+		protoID: THeaderProtocolDefault,
+	}
+}
+
+// NewTHeaderProtocolFactoryWithProtocolID creates a factory for THeader with
+// the given protocol ID, or returns an error if protoID is not supported.
+func NewTHeaderProtocolFactoryWithProtocolID(protoID THeaderProtocolID) (TProtocolFactory, error) {
+	if err := protoID.Validate(); err != nil {
+		return nil, err
+	}
+	return tHeaderProtocolFactory{
+		protoID: protoID,
+	}, nil
 }
 
 // Transport returns the underlying transport.
diff --git a/lib/go/thrift/header_protocol_test.go b/lib/go/thrift/header_protocol_test.go
index 9b6019b..f66ea64 100644
--- a/lib/go/thrift/header_protocol_test.go
+++ b/lib/go/thrift/header_protocol_test.go
@@ -24,5 +24,21 @@ import (
 )
 
 func TestReadWriteHeaderProtocol(t *testing.T) {
-	ReadWriteProtocolTest(t, NewTHeaderProtocolFactory())
+	t.Run(
+		"default", // factory built with THeaderProtocolDefault
+		func(t *testing.T) {
+			ReadWriteProtocolTest(t, NewTHeaderProtocolFactory())
+		},
+	)
+
+	t.Run(
+		"compact", // factory built with THeaderProtocolCompact
+		func(t *testing.T) {
+			f, err := NewTHeaderProtocolFactoryWithProtocolID(THeaderProtocolCompact)
+			if err != nil {
+				t.Fatal(err)
+			}
+			ReadWriteProtocolTest(t, f)
+		},
+	)
 }
diff --git a/lib/go/thrift/header_transport.go b/lib/go/thrift/header_transport.go
index e208034..562d02f 100644
--- a/lib/go/thrift/header_transport.go
+++ b/lib/go/thrift/header_transport.go
@@ -75,6 +75,15 @@ const (
 	THeaderProtocolDefault                   = THeaderProtocolBinary
 )
 
+// globalMemoryBuffer is a throwaway transport that Validate hands to GetProtocol; it is shared to avoid allocating one per call.
+var globalMemoryBuffer = NewTMemoryBuffer()
+
+// Validate returns a non-nil error if GetProtocol does not support this THeaderProtocolID.
+func (id THeaderProtocolID) Validate() error {
+	_, err := id.GetProtocol(globalMemoryBuffer)
+	return err
+}
+
 // GetProtocol gets the corresponding TProtocol from the wrapped protocol id.
 func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
 	switch id {
@@ -84,7 +93,7 @@ func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
 			fmt.Sprintf("THeader protocol id %d not supported", id),
 		)
 	case THeaderProtocolBinary:
-		return NewTBinaryProtocolFactoryDefault().GetProtocol(trans), nil
+		return NewTBinaryProtocolTransport(trans), nil
 	case THeaderProtocolCompact:
 		return NewTCompactProtocol(trans), nil
 	}
@@ -93,11 +102,12 @@ func (id THeaderProtocolID) GetProtocol(trans TTransport) (TProtocol, error) {
 // THeaderTransformID defines the numeric id of the transform used.
 type THeaderTransformID int32
 
-// THeaderTransformID values
+// THeaderTransformID values.
+//
+// Transforms not defined here (namely HMAC and Snappy) are not currently supported.
 const (
	TransformNone THeaderTransformID = iota // 0, no special handling
	TransformZlib                           // 1, zlib
-	// Rest of the values are not currently supported, namely HMAC and Snappy.
 )
 
 var supportedTransformIDs = map[THeaderTransformID]bool{
@@ -285,6 +295,34 @@ func NewTHeaderTransport(trans TTransport) *THeaderTransport {
 	}
 }
 
+// NewTHeaderTransportWithProtocolID creates THeaderTransport from the
+// underlying transport, with given protocol ID set.
+//
+// If trans is already a *THeaderTransport, it will be returned as is,
+// but with protocol ID overridden by the value passed in.
+//
+// It returns an error if the protocol ID is invalid/unsupported.
+//
+// The protocol ID override is only useful for client transports. For servers,
+// the protocol ID will be overridden again to the one set by the client,
+// to ensure that servers always speak the same dialect as the client.
+func NewTHeaderTransportWithProtocolID(trans TTransport, protoID THeaderProtocolID) (*THeaderTransport, error) {
+	if err := protoID.Validate(); err != nil {
+		return nil, err
+	}
+	if ht, ok := trans.(*THeaderTransport); ok {
+		// Avoid double wrapping, but still honor the requested protocol ID.
+		ht.protocolID = protoID
+		return ht, nil
+	}
+	return &THeaderTransport{
+		transport:    trans,
+		reader:       bufio.NewReader(trans),
+		writeHeaders: make(THeaderMap),
+		protocolID:   protoID,
+	}, nil
+}
+
 // Open calls the underlying transport's Open function.
 func (t *THeaderTransport) Open() error {
 	return t.transport.Open()
diff --git a/lib/go/thrift/header_transport_test.go b/lib/go/thrift/header_transport_test.go
index 320fb2a..5b47680 100644
--- a/lib/go/thrift/header_transport_test.go
+++ b/lib/go/thrift/header_transport_test.go
@@ -28,10 +28,13 @@ import (
 	"testing/quick"
 )
 
-func TestTHeaderHeadersReadWrite(t *testing.T) {
+func testTHeaderHeadersReadWriteProtocolID(t *testing.T, protoID THeaderProtocolID) {
 	trans := NewTMemoryBuffer()
 	reader := NewTHeaderTransport(trans)
-	writer := NewTHeaderTransport(trans)
+	writer, err := NewTHeaderTransportWithProtocolID(trans, protoID)
+	if err != nil {
+		t.Fatal(err)
+	}
 
 	const key1 = "key1"
 	const value1 = "value1"
@@ -98,10 +101,10 @@ func TestTHeaderHeadersReadWrite(t *testing.T) {
 			read,
 		)
 	}
-	if prot := reader.Protocol(); prot != THeaderProtocolBinary {
+	if prot := reader.Protocol(); prot != protoID {
 		t.Errorf(
 			"reader.Protocol() expected %d, got %d",
-			THeaderProtocolBinary,
+			protoID,
 			prot,
 		)
 	}
@@ -121,6 +124,18 @@ func TestTHeaderHeadersReadWrite(t *testing.T) {
 	}
 }
 
+func TestTHeaderHeadersReadWrite(t *testing.T) {
+	for label, id := range map[string]THeaderProtocolID{
+		"default": THeaderProtocolDefault, // currently the same ID as "binary"
+		"binary":  THeaderProtocolBinary,
+		"compact": THeaderProtocolCompact,
+	} {
+		t.Run(label, func(t *testing.T) {
+			testTHeaderHeadersReadWriteProtocolID(t, id)
+		})
+	}
+}
+
 func TestTHeaderTransportNoDoubleWrapping(t *testing.T) {
 	trans := NewTMemoryBuffer()
 	orig := NewTHeaderTransport(trans)