You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@inlong.apache.org by go...@apache.org on 2021/05/10 10:31:54 UTC
[incubator-inlong] 01/12: [INLONG-600]Multiplexed connection pool for Go sdk
This is an automated email from the ASF dual-hosted git repository.
gosonzhang pushed a commit to branch INLONG-25
in repository https://gitbox.apache.org/repos/asf/incubator-inlong.git
commit 108fd1fbefedc73fd38aaf699a0cd1e33f3bb03a
Author: Zijie Lu <ws...@gmail.com>
AuthorDate: Fri Apr 30 14:37:41 2021 +0800
[INLONG-600]Multiplexed connection pool for Go sdk
Signed-off-by: Zijie Lu <ws...@gmail.com>
---
.../tubemq-client-go/codec/codec.go | 107 ++++++
tubemq-client-twins/tubemq-client-go/go.mod | 5 +
.../tubemq-client-go/pool/multiplexed.go | 386 +++++++++++++++++++++
.../tubemq-client-go/pool/multlplexed_test.go | 119 +++++++
4 files changed, 617 insertions(+)
diff --git a/tubemq-client-twins/tubemq-client-go/codec/codec.go b/tubemq-client-twins/tubemq-client-go/codec/codec.go
new file mode 100644
index 0000000..ee27e96
--- /dev/null
+++ b/tubemq-client-twins/tubemq-client-go/codec/codec.go
@@ -0,0 +1,107 @@
+package codec
+
+import (
+ "bufio"
+ "encoding/binary"
+ "errors"
+ "io"
+)
+
+// Wire-format and buffer-size constants used by Framer.
+const (
+ // RPCProtocolBeginToken is the 4-byte token Decode expects at the start of every frame.
+ RPCProtocolBeginToken uint32 = 0xFF7FF4FE
+ // RPCMaxBufferSize is exported but not used by Decode.
+ // NOTE(review): presumably a protocol-level buffer limit — confirm against callers.
+ RPCMaxBufferSize uint32 = 8192
+ // frameHeadLen is the size in bytes of the frame header: begin token followed by serial number.
+ frameHeadLen uint32 = 8
+ // maxBufferSize is the size of the bufio.Reader that New wraps around the source reader.
+ maxBufferSize int = 128 * 1024
+ // defaultMsgSize is the initial length of Framer.msg; Decode grows it when a frame needs more.
+ defaultMsgSize int = 4096
+ // dataLen is the size in bytes of the length prefix in front of each data block.
+ dataLen uint32 = 4
+ // listSizeLen is the size in bytes of the block-count field that follows the header.
+ listSizeLen uint32 = 4
+ // serialNoLen is the size in bytes of the serial number inside the header.
+ serialNoLen uint32 = 4
+ // beginTokenLen is the size in bytes of the begin token inside the header.
+ beginTokenLen uint32 = 4
+)
+
+// Framer decodes response frames from a reader.
+type Framer struct {
+ // reader is the buffered source that Decode reads from.
+ reader io.Reader
+ // msg is scratch space reused across Decode calls; it holds the frame
+ // header followed by the frame's data.
+ msg []byte
+}
+
+// New returns a Framer that reads frames from reader through a bufio.Reader
+// of maxBufferSize bytes, with a scratch buffer of defaultMsgSize bytes.
+func New(reader io.Reader) *Framer {
+	f := new(Framer)
+	f.reader = bufio.NewReaderSize(reader, maxBufferSize)
+	f.msg = make([]byte, defaultMsgSize)
+	return f
+}
+
+// Decode reads one complete frame from the underlying reader.
+//
+// The layout consumed here is: a 4-byte begin token, a 4-byte serial number,
+// a 4-byte block count, then for each block a 4-byte length followed by that
+// many bytes of data. All integers are big-endian.
+//
+// It returns a FrameResponse holding the serial number and a private copy of
+// the concatenated block data. Every read error (including io.EOF and
+// io.ErrUnexpectedEOF) is returned unchanged so callers can tell a closed
+// connection from a malformed frame; a wrong begin token yields a descriptive
+// error.
+func (f *Framer) Decode() (*FrameResponse, error) {
+	// io.ReadFull returns a nil error only when the buffer was filled
+	// completely, so no separate byte-count check is needed.
+	if _, err := io.ReadFull(f.reader, f.msg[:frameHeadLen]); err != nil {
+		return nil, err
+	}
+	token := binary.BigEndian.Uint32(f.msg[:beginTokenLen])
+	if token != RPCProtocolBeginToken {
+		return nil, errors.New("framer: read framer rpc protocol begin token not match")
+	}
+
+	// The block count is parsed right away, so its bytes in f.msg may be
+	// overwritten by block data below.
+	listSizeBuf := f.msg[frameHeadLen : frameHeadLen+listSizeLen]
+	if _, err := io.ReadFull(f.reader, listSizeBuf); err != nil {
+		return nil, err
+	}
+	listSize := binary.BigEndian.Uint32(listSizeBuf)
+
+	// totalLen is the number of valid bytes in f.msg: header plus data so far.
+	totalLen := int(frameHeadLen)
+	size := make([]byte, dataLen)
+	for i := uint32(0); i < listSize; i++ {
+		if _, err := io.ReadFull(f.reader, size); err != nil {
+			return nil, err
+		}
+
+		// NOTE(review): the block length comes straight from the peer and is
+		// not capped here — confirm whether a protocol maximum should apply.
+		s := int(binary.BigEndian.Uint32(size))
+		if totalLen+s > len(f.msg) {
+			// Grow the scratch buffer, keeping the header and earlier blocks.
+			grown := make([]byte, totalLen+s)
+			copy(grown, f.msg[:totalLen])
+			f.msg = grown
+		}
+
+		if _, err := io.ReadFull(f.reader, f.msg[totalLen:totalLen+s]); err != nil {
+			return nil, err
+		}
+		totalLen += s
+	}
+
+	// Copy the data out because f.msg is reused by the next Decode call.
+	data := make([]byte, totalLen-int(frameHeadLen))
+	copy(data, f.msg[frameHeadLen:totalLen])
+
+	return &FrameResponse{
+		serialNo:    binary.BigEndian.Uint32(f.msg[beginTokenLen : beginTokenLen+serialNoLen]),
+		responseBuf: data,
+	}, nil
+}
+
+// FrameRequest holds a request ID together with request bytes.
+// NOTE(review): it is not referenced by any other code in this change.
+type FrameRequest struct {
+ requestID uint32
+ req []byte
+}
+
+// FrameResponse is one decoded frame as produced by Framer.Decode.
+type FrameResponse struct {
+ // serialNo is the serial number read from the frame header.
+ serialNo uint32
+ // responseBuf is the concatenated data of all blocks in the frame.
+ responseBuf []byte
+}
+
+// GetSerialNo returns the serial number read from the frame header.
+func (f *FrameResponse) GetSerialNo() uint32 {
+ return f.serialNo
+}
+
+// GetResponseBuf returns the frame's data. The slice is returned as-is, not copied.
+func (f *FrameResponse) GetResponseBuf() []byte {
+ return f.responseBuf
+}
+
+// Codec is an empty type; no methods are defined on it in this change.
+type Codec struct{}
diff --git a/tubemq-client-twins/tubemq-client-go/go.mod b/tubemq-client-twins/tubemq-client-go/go.mod
new file mode 100644
index 0000000..7c1a676
--- /dev/null
+++ b/tubemq-client-twins/tubemq-client-go/go.mod
@@ -0,0 +1,5 @@
+module github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go
+
+go 1.14
+
+require github.com/stretchr/testify v1.7.0
diff --git a/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go b/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go
new file mode 100644
index 0000000..5d38a14
--- /dev/null
+++ b/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go
@@ -0,0 +1,386 @@
+package pool
+
+import (
+ "context"
+ "crypto/tls"
+ "crypto/x509"
+ "errors"
+ "io/ioutil"
+ "net"
+ "sync"
+ "time"
+
+ "github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go/codec"
+)
+
+// DefaultMultiplexedPool is a package-level pool created with New when the package is initialized.
+var DefaultMultiplexedPool = New()
+
+// Sentinel errors returned by the pool.
+var (
+ // ErrConnClosed is returned by Connection.send when the connection state is Closed.
+ ErrConnClosed = errors.New("connection is closed")
+ // ErrChanClose is returned by MultiplexedConnection.Read when its receive
+ // channel was closed and the connection recorded no error.
+ ErrChanClose = errors.New("unexpected recv chan close")
+ // ErrWriteBufferDone is returned by writerBuffer.get when its done channel is closed.
+ ErrWriteBufferDone = errors.New("write buffer done")
+ // ErrAssertConnectionFail is returned by Multiplexed.Get when the value
+ // stored for an address is not a *Connection.
+ ErrAssertConnectionFail = errors.New("assert connection slice fail")
+)
+
+// Values of Connection.state.
+const (
+ // Initial is the state of a Connection that has been created but not yet dialed.
+ Initial int = iota
+ // Connected is set once a dial (or reconnect) succeeded.
+ Connected
+ // Closing is set by Connection.close while it tears down the connection and reconnects.
+ Closing
+ // Closed is set when reconnecting failed.
+ Closed
+)
+
+// queueSize is the capacity of each Connection's outgoing request channel.
+var queueSize = 10000
+
+// New returns an empty Multiplexed pool.
+func New() *Multiplexed {
+	return &Multiplexed{connections: new(sync.Map)}
+}
+
+// writerBuffer is the queue of outgoing requests for one Connection.
+type writerBuffer struct {
+ // buffer carries the encoded requests waiting to be written.
+ buffer chan []byte
+ // done unblocks get when it is closed.
+ done <-chan struct{}
+}
+
+// get blocks until a queued request is available and returns it, or returns
+// ErrWriteBufferDone once the done channel is closed.
+func (w *writerBuffer) get() ([]byte, error) {
+	select {
+	case <-w.done:
+		return nil, ErrWriteBufferDone
+	case pending := <-w.buffer:
+		return pending, nil
+	}
+}
+
+// recvReader is the receiving side of one MultiplexedConnection.
+type recvReader struct {
+ // ctx is the context given to Connection.new; Read stops waiting when it is done.
+ ctx context.Context
+ // recv delivers the response; Connection.new creates it with capacity 1.
+ recv chan *codec.FrameResponse
+}
+
+// MultiplexedConnection is a virtual connection for a single serial number
+// that shares the underlying Connection with other requests.
+type MultiplexedConnection struct {
+ // serialNo is the key under which this value is registered in Connection.connections.
+ serialNo uint32
+ // conn is the shared Connection used to send and receive.
+ conn *Connection
+ // reader receives the response for serialNo.
+ reader *recvReader
+ // done is the parent Connection's mDone channel; Read returns when it is closed.
+ done chan struct{}
+}
+
+// Write queues b for sending on the shared connection. If queuing fails, this
+// virtual connection is unregistered and the error is returned.
+func (mc *MultiplexedConnection) Write(b []byte) error {
+	err := mc.conn.send(b)
+	if err != nil {
+		mc.conn.remove(mc.serialNo)
+	}
+	return err
+}
+
+// Read waits for the response belonging to this virtual connection.
+//
+// It returns the context's error (after unregistering itself) when the context
+// ends first, the shared connection's error when the done channel is closed,
+// and, when the receive channel is closed without a value, the connection's
+// error if one is recorded or ErrChanClose otherwise.
+func (mc *MultiplexedConnection) Read() (*codec.FrameResponse, error) {
+	waiter := mc.reader
+	select {
+	case <-waiter.ctx.Done():
+		mc.conn.remove(mc.serialNo)
+		return nil, waiter.ctx.Err()
+	case <-mc.done:
+		return nil, mc.conn.err
+	case rsp, open := <-waiter.recv:
+		switch {
+		case open:
+			return rsp, nil
+		case mc.conn.err != nil:
+			return nil, mc.conn.err
+		default:
+			return nil, ErrChanClose
+		}
+	}
+}
+
+// recv hands rsp to the waiting Read and then unregisters the response's
+// serial number from the shared connection.
+func (mc *MultiplexedConnection) recv(rsp *codec.FrameResponse) {
+	id := rsp.GetSerialNo()
+	mc.reader.recv <- rsp
+	mc.conn.remove(id)
+}
+
+// DialOptions describes how dialWithTimeout establishes a connection.
+type DialOptions struct {
+ // Network and Address are passed to the dialer unchanged.
+ Network string
+ Address string
+ // Timeout is the dial timeout.
+ Timeout time.Duration
+ // CACertFile selects the transport: empty means a plain connection, "none"
+ // means TLS without verifying the server certificate, "root" means TLS with
+ // no custom CA pool, and any other value is read as a PEM file of CA certificates.
+ CACertFile string
+ // TLSCertFile and TLSKeyFile name the client key pair; it is loaded only
+ // when TLSCertFile is set.
+ TLSCertFile string
+ TLSKeyFile string
+ // TLSServerName is the server name used in the TLS configuration.
+ TLSServerName string
+}
+
+// Connection is one network connection shared by many MultiplexedConnections.
+type Connection struct {
+ // err is the error recorded by close; reconnect resets it to nil.
+ err error
+ // address is the address given to Multiplexed.Get.
+ address string
+ // mu is held by new, remove and close, and read-locked by reader for its lookup.
+ mu sync.RWMutex
+ // connections holds the registered virtual connections keyed by serial number.
+ connections map[uint32]*MultiplexedConnection
+ // framer decodes responses read from conn.
+ framer *codec.Framer
+ // conn is the underlying network connection.
+ conn net.Conn
+ // done is closed by close to stop the current reader and writer; reconnect replaces it.
+ done chan struct{}
+ // mDone is closed when reconnecting fails; every MultiplexedConnection uses it as its done channel.
+ mDone chan struct{}
+ // buffer queues outgoing requests for the writer goroutine.
+ buffer *writerBuffer
+ // dialOpts are the options of the initial dial, reused by reconnect.
+ dialOpts *DialOptions
+ // state is one of Initial, Connected, Closing or Closed.
+ state int
+ // multiplexed refers to the pool this connection belongs to.
+ multiplexed *Multiplexed
+}
+
+// new registers a virtual connection for serialNo and returns it.
+//
+// It fails with the connection's recorded error if there is one. If serialNo
+// is already registered, the previous entry's receive channel is closed, which
+// wakes its pending Read, and the entry is replaced.
+func (c *Connection) new(ctx context.Context, serialNo uint32) (*MultiplexedConnection, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if err := c.err; err != nil {
+		return nil, err
+	}
+
+	if stale, exists := c.connections[serialNo]; exists {
+		close(stale.reader.recv)
+	}
+
+	mc := &MultiplexedConnection{
+		serialNo: serialNo,
+		conn:     c,
+		done:     c.mDone,
+		reader: &recvReader{
+			ctx:  ctx,
+			recv: make(chan *codec.FrameResponse, 1),
+		},
+	}
+	c.connections[serialNo] = mc
+	return mc, nil
+}
+
+// close tears down the current network connection after a reader or writer
+// failure and attempts to reconnect.
+//
+// lastErr is the failure that triggered the call; a nil lastErr makes close a
+// no-op. done is the done channel supplied by the caller: if it is already
+// closed, or the connection is already Closed, close returns without acting.
+//
+// Otherwise it records lastErr, drops all registered virtual connections,
+// closes c.done and the network connection, and calls reconnect. If
+// reconnecting fails the connection becomes Closed, mDone is closed so that
+// waiting readers and senders return, and the connection is removed from the
+// owning pool.
+//
+// c.mu is held for the whole call, including the reconnect dial.
+func (c *Connection) close(lastErr error, done chan struct{}) {
+	if lastErr == nil {
+		return
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.state == Closed {
+		return
+	}
+
+	select {
+	case <-done:
+		return
+	default:
+	}
+
+	c.state = Closing
+	c.err = lastErr
+	c.connections = make(map[uint32]*MultiplexedConnection)
+	close(c.done)
+	if c.conn != nil {
+		// Best effort: the connection is already considered broken.
+		_ = c.conn.Close()
+	}
+	if err := c.reconnect(); err != nil {
+		c.state = Closed
+		close(c.mDone)
+		// The pool map is keyed by address, so delete by address. The nil
+		// guard covers connections that carry no pool reference.
+		if c.multiplexed != nil {
+			c.multiplexed.connections.Delete(c.address)
+		}
+	}
+}
+
+// reconnect dials a new network connection with the stored dial options and,
+// on success, installs it and starts fresh reader and writer goroutines.
+// It is called from close, which holds c.mu during the call.
+// On dial failure it returns the error without modifying the Connection.
+func (c *Connection) reconnect() error {
+ conn, err := dialWithTimeout(c.dialOpts)
+ if err != nil {
+ return err
+ }
+ // close has already closed the previous done channel, so a new one is needed.
+ c.done = make(chan struct{})
+ c.conn = conn
+ c.framer = codec.New(conn)
+ // Point the write buffer at the new done channel so that get blocks again.
+ c.buffer.done = c.done
+ c.state = Connected
+ c.err = nil
+ go c.reader()
+ go c.writer()
+ return nil
+}
+
+// writer is the send loop of a Connection. It takes queued requests from the
+// write buffer and writes them to the network connection until c.done is
+// closed (in which case it just returns) or an error occurs (in which case it
+// passes the error to close).
+func (c *Connection) writer() {
+	var lastErr error
+	for lastErr == nil {
+		select {
+		case <-c.done:
+			return
+		default:
+		}
+		req, err := c.buffer.get()
+		if err == nil {
+			err = c.write(req)
+		}
+		lastErr = err
+	}
+	c.close(lastErr, c.done)
+}
+
+// send queues b for the writer goroutine. It returns ErrConnClosed if the
+// connection is Closed, and the connection's recorded error if mDone is
+// closed while waiting for room in the queue.
+func (c *Connection) send(b []byte) error {
+	if c.state != Closed {
+		select {
+		case <-c.mDone:
+			return c.err
+		case c.buffer.buffer <- b:
+			return nil
+		}
+	}
+	return ErrConnClosed
+}
+
+// remove unregisters the virtual connection with the given serial number.
+// Removing an unknown id is a no-op.
+func (c *Connection) remove(id uint32) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	delete(c.connections, id)
+}
+
+// write sends all of b on the network connection, repeating the call after
+// short writes, and returns the first write error.
+func (c *Connection) write(b []byte) error {
+	for rest := b; len(rest) > 0; {
+		n, err := c.conn.Write(rest)
+		if err != nil {
+			return err
+		}
+		rest = rest[n:]
+	}
+	return nil
+}
+
+// reader is the receive loop of a Connection. It decodes frames and delivers
+// each one to the virtual connection registered under the frame's serial
+// number; frames with no registered receiver are dropped. The loop returns
+// when c.done is closed and passes any decode error to close.
+func (c *Connection) reader() {
+	var lastErr error
+	for lastErr == nil {
+		select {
+		case <-c.done:
+			return
+		default:
+		}
+		rsp, err := c.framer.Decode()
+		if err != nil {
+			lastErr = err
+			continue
+		}
+
+		c.mu.RLock()
+		mc, registered := c.connections[rsp.GetSerialNo()]
+		c.mu.RUnlock()
+		if registered {
+			// recv delivers the response and unregisters the serial number.
+			mc.recv(rsp)
+		}
+	}
+	c.close(lastErr, c.done)
+}
+
+// Multiplexed is a pool of shared connections.
+type Multiplexed struct {
+ // connections maps an address to its *Connection.
+ connections *sync.Map
+}
+
+// Get returns a virtual connection for serialNo on the shared connection to
+// address, creating and dialing that connection on first use.
+//
+// It returns ctx.Err() if ctx is already done, ErrAssertConnectionFail if the
+// pool holds a value of the wrong type for address, and the dial error if a
+// new connection cannot be established. A connection whose dial failed is
+// removed from the pool again, so a later Get starts with a fresh dial.
+func (p *Multiplexed) Get(ctx context.Context, address string, serialNo uint32) (*MultiplexedConnection, error) {
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	default:
+	}
+
+	// Fast path: the connection already exists.
+	if v, ok := p.connections.Load(address); ok {
+		c, ok := v.(*Connection)
+		if !ok {
+			return nil, ErrAssertConnectionFail
+		}
+		return c.new(ctx, serialNo)
+	}
+
+	c := &Connection{
+		address:     address,
+		connections: make(map[uint32]*MultiplexedConnection),
+		done:        make(chan struct{}),
+		mDone:       make(chan struct{}),
+		state:       Initial,
+		multiplexed: p, // close uses this to drop the connection from the pool
+	}
+	c.buffer = &writerBuffer{
+		buffer: make(chan []byte, queueSize),
+		done:   c.done,
+	}
+
+	// LoadOrStore makes publication atomic: if another caller registered a
+	// connection for this address in the meantime, use theirs instead of
+	// dialing a duplicate.
+	if v, loaded := p.connections.LoadOrStore(address, c); loaded {
+		existing, ok := v.(*Connection)
+		if !ok {
+			return nil, ErrAssertConnectionFail
+		}
+		return existing.new(ctx, serialNo)
+	}
+
+	conn, dialOpts, err := dial(ctx, address)
+	c.dialOpts = dialOpts
+	if err != nil {
+		// Mark the connection dead before unpublishing it, so callers that
+		// picked it up while the dial was in flight fail promptly instead of
+		// waiting on a connection that will never come up.
+		c.mu.Lock()
+		c.err = err
+		c.state = Closed
+		close(c.mDone)
+		c.mu.Unlock()
+		p.connections.Delete(address)
+		return nil, err
+	}
+	c.framer = codec.New(conn)
+	c.conn = conn
+	c.state = Connected
+	go c.reader()
+	go c.writer()
+	return c.new(ctx, serialNo)
+}
+
+func dial(ctx context.Context, address string) (net.Conn, *DialOptions, error) {
+ var timeout time.Duration
+ t, ok := ctx.Deadline()
+ if ok {
+ timeout = t.Sub(time.Now())
+ }
+ dialOpts := &DialOptions{
+ Network: "tcp",
+ Address: address,
+ Timeout: timeout,
+ }
+ select {
+ case <-ctx.Done():
+ return nil, dialOpts, ctx.Err()
+ default:
+ }
+ conn, err := dialWithTimeout(dialOpts)
+ return conn, dialOpts, err
+}
+
+// dialWithTimeout opens a connection as described by opts.
+//
+// With an empty CACertFile it dials a plain connection. Otherwise it dials
+// TLS: CACertFile "none" disables verification of the server certificate; any
+// other value configures the CA pool returned by getCertPool and, when
+// TLSCertFile is set, a client certificate loaded from TLSCertFile and
+// TLSKeyFile.
+//
+// It returns the error from loading certificates or from dialing.
+func dialWithTimeout(opts *DialOptions) (net.Conn, error) {
+	if len(opts.CACertFile) == 0 {
+		return net.DialTimeout(opts.Network, opts.Address, opts.Timeout)
+	}
+
+	tlsConf := &tls.Config{}
+	if opts.CACertFile == "none" {
+		// No need to verify the server certificate.
+		tlsConf.InsecureSkipVerify = true
+	} else {
+		// Do not fall back to opts.Address here: it includes the port, and a
+		// "host:port" server name never matches a certificate. When
+		// ServerName is left empty, tls.DialWithDialer derives it from the
+		// host part of the address.
+		tlsConf.ServerName = opts.TLSServerName
+
+		certPool, err := getCertPool(opts.CACertFile)
+		if err != nil {
+			return nil, err
+		}
+		tlsConf.RootCAs = certPool
+
+		if len(opts.TLSCertFile) != 0 {
+			cert, err := tls.LoadX509KeyPair(opts.TLSCertFile, opts.TLSKeyFile)
+			if err != nil {
+				return nil, err
+			}
+			tlsConf.Certificates = []tls.Certificate{cert}
+		}
+	}
+	return tls.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, opts.Network, opts.Address, tlsConf)
+}
+
+// getCertPool builds the CA pool used to verify the server certificate.
+//
+// For caCertFile "root" it returns a nil pool and a nil error; a nil RootCAs
+// in tls.Config selects the host's root CA set. For any other value the file
+// is read and parsed as PEM. It returns an error if the file cannot be read or
+// contains no certificate that could be parsed.
+func getCertPool(caCertFile string) (*x509.CertPool, error) {
+	if caCertFile == "root" {
+		return nil, nil
+	}
+
+	ca, err := ioutil.ReadFile(caCertFile)
+	if err != nil {
+		return nil, err
+	}
+	certPool := x509.NewCertPool()
+	if !certPool.AppendCertsFromPEM(ca) {
+		// Report the failure explicitly: returning a nil pool with a nil
+		// error would make the caller verify against the host's roots instead
+		// of the requested CA file.
+		return nil, errors.New("pool: no valid PEM certificate found in CA cert file " + caCertFile)
+	}
+	return certPool, nil
+}
diff --git a/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go b/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go
new file mode 100644
index 0000000..6377032
--- /dev/null
+++ b/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go
@@ -0,0 +1,119 @@
+package pool
+
+import (
+ "bytes"
+ "context"
+ "encoding/binary"
+ "io"
+ "log"
+ "net"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+
+ "github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go/codec"
+)
+
+// Shared state for the tests in this file.
+var (
+ // address is the listen address; simpleForwardTCPServer replaces it with
+ // the listener's actual address once it is listening.
+ address = "127.0.0.1:0"
+ // ch is signalled by simpleForwardTCPServer when the listener is ready.
+ ch = make(chan struct{})
+ // serialNo is the counter the tests increment atomically to get serial numbers.
+ serialNo uint32 = 1
+)
+
+// init starts the echo server in a goroutine and blocks until the server
+// signals on ch, at which point address holds the listener's address.
+func init() {
+ go simpleForwardTCPServer(ch)
+ <-ch
+}
+
+// simpleForwardTCPServer runs a TCP echo server for the tests. It listens on
+// address, stores the listener's actual address back into address, signals
+// readiness on ch, and then echoes every accepted connection's input back to
+// it. Listen and accept errors terminate the process via log.Fatal.
+func simpleForwardTCPServer(ch chan struct{}) {
+	l, err := net.Listen("tcp", address)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer l.Close()
+	address = l.Addr().String()
+
+	ch <- struct{}{}
+
+	for {
+		conn, err := l.Accept()
+		if err != nil {
+			log.Fatal(err)
+		}
+
+		go func(conn net.Conn) {
+			// Close the connection once the peer is done so descriptors are
+			// not leaked across the many connections the tests open.
+			defer conn.Close()
+			// The copy error is irrelevant for an echo helper; the client
+			// side of the test reports any failure.
+			_, _ = io.Copy(conn, conn)
+		}(conn)
+	}
+}
+
+// Encode builds a frame in the layout Framer.Decode reads: begin token,
+// serialNo, a block count of 1, the body length, then body. All integers are
+// written big-endian.
+func Encode(serialNo uint32, body []byte) ([]byte, error) {
+	bodyLen := len(body)
+	buf := bytes.NewBuffer(make([]byte, 0, 16+bodyLen))
+
+	header := []uint32{codec.RPCProtocolBeginToken, serialNo, 1, uint32(bodyLen)}
+	for _, field := range header {
+		if err := binary.Write(buf, binary.BigEndian, field); err != nil {
+			return nil, err
+		}
+	}
+
+	buf.Write(body)
+	return buf.Bytes(), nil
+}
+
+// TestBasicMultiplexed sends one frame through a fresh pool to the echo
+// server and checks that the same serial number and body come back.
+func TestBasicMultiplexed(t *testing.T) {
+	serialNo := atomic.AddUint32(&serialNo, 1)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+
+	m := New()
+	mc, err := m.Get(ctx, address, serialNo)
+	// Stop here on failure: mc is nil when Get returns an error.
+	if !assert.Nil(t, err) {
+		return
+	}
+	body := []byte("hello world")
+
+	buf, err := Encode(serialNo, body)
+	assert.Nil(t, err)
+	assert.Nil(t, mc.Write(buf))
+
+	rsp, err := mc.Read()
+	// Stop here on failure: rsp is nil when Read returns an error.
+	if !assert.Nil(t, err) {
+		return
+	}
+	assert.Equal(t, serialNo, rsp.GetSerialNo())
+	assert.Equal(t, body, rsp.GetResponseBuf())
+	assert.Nil(t, mc.Write(nil))
+}
+
+// TestConcurrentMultiplexed runs 1000 goroutines that share one pool, each
+// sending a frame with its own serial number to the echo server and checking
+// that it receives its own body back.
+func TestConcurrentMultiplexed(t *testing.T) {
+	count := 1000
+	m := New()
+	wg := sync.WaitGroup{}
+	wg.Add(count)
+	for i := 0; i < count; i++ {
+		go func(i int) {
+			defer wg.Done()
+			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+			defer cancel()
+			serialNo := atomic.AddUint32(&serialNo, 1)
+			mc, err := m.Get(ctx, address, serialNo)
+			// assert does not stop the goroutine, so return explicitly:
+			// mc is nil when Get returns an error.
+			if !assert.Nil(t, err) {
+				return
+			}
+
+			body := []byte("hello world" + strconv.Itoa(i))
+			buf, err := Encode(serialNo, body)
+			assert.Nil(t, err)
+			assert.Nil(t, mc.Write(buf))
+
+			rsp, err := mc.Read()
+			// rsp is nil when Read returns an error.
+			if !assert.Nil(t, err) {
+				return
+			}
+			assert.Equal(t, serialNo, rsp.GetSerialNo())
+			assert.Equal(t, body, rsp.GetResponseBuf())
+		}(i)
+	}
+	wg.Wait()
+}