You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@inlong.apache.org by go...@apache.org on 2021/05/10 10:31:54 UTC

[incubator-inlong] 01/12: [INLONG-600]Multiplexed connection pool for Go sdk

This is an automated email from the ASF dual-hosted git repository.

gosonzhang pushed a commit to branch INLONG-25
in repository https://gitbox.apache.org/repos/asf/incubator-inlong.git

commit 108fd1fbefedc73fd38aaf699a0cd1e33f3bb03a
Author: Zijie Lu <ws...@gmail.com>
AuthorDate: Fri Apr 30 14:37:41 2021 +0800

    [INLONG-600]Multiplexed connection pool for Go sdk
    
    Signed-off-by: Zijie Lu <ws...@gmail.com>
---
 .../tubemq-client-go/codec/codec.go                | 107 ++++++
 tubemq-client-twins/tubemq-client-go/go.mod        |   5 +
 .../tubemq-client-go/pool/multiplexed.go           | 386 +++++++++++++++++++++
 .../tubemq-client-go/pool/multlplexed_test.go      | 119 +++++++
 4 files changed, 617 insertions(+)

diff --git a/tubemq-client-twins/tubemq-client-go/codec/codec.go b/tubemq-client-twins/tubemq-client-go/codec/codec.go
new file mode 100644
index 0000000..ee27e96
--- /dev/null
+++ b/tubemq-client-twins/tubemq-client-go/codec/codec.go
@@ -0,0 +1,107 @@
+package codec
+
+import (
+	"bufio"
+	"encoding/binary"
+	"errors"
+	"io"
+)
+
const (
	// RPCProtocolBeginToken is the magic number every frame starts with.
	RPCProtocolBeginToken uint32 = 0xFF7FF4FE
	// RPCMaxBufferSize is exported for other packages; the Framer in this
	// file does not consult it.
	RPCMaxBufferSize uint32 = 8192
)

const (
	// frameHeadLen is the fixed frame header: begin token followed by the
	// serial number.
	frameHeadLen uint32 = 8
	// maxBufferSize is the size of the bufio.Reader New wraps around the source.
	maxBufferSize int = 128 * 1024
	// defaultMsgSize is the initial size of a Framer's reusable scratch buffer.
	defaultMsgSize int = 4096
	// dataLen is the width of each list element's length prefix.
	dataLen uint32 = 4
	// listSizeLen is the width of the element-count field after the header.
	listSizeLen uint32 = 4
	// serialNoLen is the width of the serial number inside the header.
	serialNoLen uint32 = 4
	// beginTokenLen is the width of the begin token inside the header.
	beginTokenLen uint32 = 4
)

// Framer decodes response frames from a byte stream. It reuses an internal
// scratch buffer across Decode calls without any locking, so one Framer must
// not be used by several goroutines at once.
type Framer struct {
	reader io.Reader
	// msg is scratch space reused by Decode; it only ever grows.
	msg []byte
}

// New returns a Framer that reads from reader through a 128 KiB bufio.Reader.
func New(reader io.Reader) *Framer {
	bufferReader := bufio.NewReaderSize(reader, maxBufferSize)
	return &Framer{
		msg:    make([]byte, defaultMsgSize),
		reader: bufferReader,
	}
}

// Decode reads one frame and returns its serial number and payload.
//
// Wire format (every integer is a big-endian uint32):
//
//	begin token | serial number | list size | list size x (length | bytes)
//
// The payload is the concatenation of all list elements, copied into a fresh
// slice so it stays valid after later Decode calls. Errors from the
// underlying reader (as reported by io.ReadFull, e.g. io.EOF or
// io.ErrUnexpectedEOF) are returned unchanged so callers can tell a closed
// stream from a protocol error.
func (f *Framer) Decode() (*FrameResponse, error) {
	// maxInt guards the int conversions below on 32-bit platforms.
	const maxInt = int(^uint(0) >> 1)

	head := f.msg[:frameHeadLen]
	if _, err := io.ReadFull(f.reader, head); err != nil {
		return nil, err
	}
	if binary.BigEndian.Uint32(head[:beginTokenLen]) != RPCProtocolBeginToken {
		return nil, errors.New("framer: read framer rpc protocol begin token not match")
	}
	// Extract the serial number now: the header area of f.msg is reused for
	// the payload below.
	serialNo := binary.BigEndian.Uint32(head[beginTokenLen : beginTokenLen+serialNoLen])

	// word receives the list size and then each element's length prefix.
	word := make([]byte, dataLen)
	if _, err := io.ReadFull(f.reader, word[:listSizeLen]); err != nil {
		return nil, err
	}
	listSize := binary.BigEndian.Uint32(word[:listSizeLen])

	bodyLen := 0
	for i := uint32(0); i < listSize; i++ {
		if _, err := io.ReadFull(f.reader, word[:dataLen]); err != nil {
			return nil, err
		}
		size := binary.BigEndian.Uint32(word[:dataLen])
		if uint64(size) > uint64(maxInt-bodyLen) {
			return nil, errors.New("framer: frame too large")
		}
		end := bodyLen + int(size)
		f.grow(bodyLen, end)
		if _, err := io.ReadFull(f.reader, f.msg[bodyLen:end]); err != nil {
			return nil, err
		}
		bodyLen = end
	}

	data := make([]byte, bodyLen)
	copy(data, f.msg[:bodyLen])
	return &FrameResponse{
		serialNo:    serialNo,
		responseBuf: data,
	}, nil
}

// grow makes f.msg at least n bytes long, preserving its first keep bytes.
// It at least doubles the buffer so frames with many elements do not copy
// the accumulated payload on every element.
func (f *Framer) grow(keep, n int) {
	if n <= len(f.msg) {
		return
	}
	newLen := 2 * len(f.msg)
	if newLen < n {
		newLen = n
	}
	buf := make([]byte, newLen)
	copy(buf, f.msg[:keep])
	f.msg = buf
}

// FrameRequest pairs a request ID with its encoded bytes. Nothing in this
// file constructs it.
type FrameRequest struct {
	requestID uint32
	req       []byte
}

// FrameResponse is a decoded frame: its serial number and its payload.
type FrameResponse struct {
	serialNo    uint32
	responseBuf []byte
}

// GetSerialNo returns the serial number carried in the frame header.
func (f *FrameResponse) GetSerialNo() uint32 {
	return f.serialNo
}

// GetResponseBuf returns the concatenated payload of the frame's list elements.
func (f *FrameResponse) GetResponseBuf() []byte {
	return f.responseBuf
}

// Codec is an empty placeholder type; nothing in this file uses it.
type Codec struct{}
diff --git a/tubemq-client-twins/tubemq-client-go/go.mod b/tubemq-client-twins/tubemq-client-go/go.mod
new file mode 100644
index 0000000..7c1a676
--- /dev/null
+++ b/tubemq-client-twins/tubemq-client-go/go.mod
@@ -0,0 +1,5 @@
+module github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go
+
+go 1.14
+
+require github.com/stretchr/testify v1.7.0
diff --git a/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go b/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go
new file mode 100644
index 0000000..5d38a14
--- /dev/null
+++ b/tubemq-client-twins/tubemq-client-go/pool/multiplexed.go
@@ -0,0 +1,386 @@
+package pool
+
+import (
+	"context"
+	"crypto/tls"
+	"crypto/x509"
+	"errors"
+	"io/ioutil"
+	"net"
+	"sync"
+	"time"
+
+	"github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go/codec"
+)
+
// DefaultMultiplexedPool is a shared pool created with New at package
// initialization.
var DefaultMultiplexedPool = New()
+
// Sentinel errors returned by the pool and its connections.
var (
	// ErrConnClosed is returned by send once the Connection is in state Closed.
	ErrConnClosed = errors.New("connection is closed")
	// ErrChanClose is returned by Read when its receive channel was closed
	// and the Connection has no recorded error.
	ErrChanClose = errors.New("unexpected recv chan close")
	// ErrWriteBufferDone is returned by writerBuffer.get once its done
	// channel is closed.
	ErrWriteBufferDone = errors.New("write buffer done")
	// ErrAssertConnectionFail is returned by Get when the pooled value for an
	// address is not a *Connection.
	ErrAssertConnectionFail = errors.New("assert connection slice fail")
)

// Lifecycle states stored in Connection.state.
const (
	// Initial is the state of a Connection that has not finished dialing.
	Initial int = iota
	// Connected is set once the socket is dialed and the reader/writer
	// goroutines have been started.
	Connected
	// Closing is set by close while the connection is torn down and redialed.
	Closing
	// Closed means redialing failed and the Connection is unusable.
	Closed
)

// queueSize is the capacity of each Connection's outgoing request queue.
var queueSize = 10000
+
+func New() *Multiplexed {
+	m := &Multiplexed{
+		connections: new(sync.Map),
+	}
+	return m
+}
+
+type writerBuffer struct {
+	buffer chan []byte
+	done   <-chan struct{}
+}
+
+func (w *writerBuffer) get() ([]byte, error) {
+	select {
+	case req := <-w.buffer:
+		return req, nil
+	case <-w.done:
+		return nil, ErrWriteBufferDone
+	}
+}
+
+type recvReader struct {
+	ctx  context.Context
+	recv chan *codec.FrameResponse
+}
+
+type MultiplexedConnection struct {
+	serialNo uint32
+	conn     *Connection
+	reader   *recvReader
+	done     chan struct{}
+}
+
+func (mc *MultiplexedConnection) Write(b []byte) error {
+	if err := mc.conn.send(b); err != nil {
+		mc.conn.remove(mc.serialNo)
+		return err
+	}
+	return nil
+}
+
+func (mc *MultiplexedConnection) Read() (*codec.FrameResponse, error) {
+	select {
+	case <-mc.reader.ctx.Done():
+		mc.conn.remove(mc.serialNo)
+		return nil, mc.reader.ctx.Err()
+	case v, ok := <-mc.reader.recv:
+		if ok {
+			return v, nil
+		}
+		if mc.conn.err != nil {
+			return nil, mc.conn.err
+		}
+		return nil, ErrChanClose
+	case <-mc.done:
+		return nil, mc.conn.err
+	}
+}
+
+func (mc *MultiplexedConnection) recv(rsp *codec.FrameResponse) {
+	mc.reader.recv <- rsp
+	mc.conn.remove(rsp.GetSerialNo())
+}
+
// DialOptions describes how a Connection dials, and later redials, its peer.
type DialOptions struct {
	// Network is the network name passed to the dialer, e.g. "tcp".
	Network string
	// Address is the address to dial.
	Address string
	// Timeout bounds each dial; zero means no timeout.
	Timeout time.Duration
	// CACertFile enables TLS when non-empty: "none" skips verification of the
	// server certificate, "root" uses the system roots, and any other value
	// is read as a PEM file of trusted CA certificates.
	CACertFile string
	// TLSCertFile and TLSKeyFile optionally name a client certificate key pair.
	TLSCertFile string
	TLSKeyFile  string
	// TLSServerName overrides the name checked against the server certificate.
	TLSServerName string
}
+
+type Connection struct {
+	err         error
+	address     string
+	mu          sync.RWMutex
+	connections map[uint32]*MultiplexedConnection
+	framer      *codec.Framer
+	conn        net.Conn
+	done        chan struct{}
+	mDone       chan struct{}
+	buffer      *writerBuffer
+	dialOpts    *DialOptions
+	state       int
+	multiplexed *Multiplexed
+}
+
+func (c *Connection) new(ctx context.Context, serialNo uint32) (*MultiplexedConnection, error) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.err != nil {
+		return nil, c.err
+	}
+
+	vc := &MultiplexedConnection{
+		serialNo: serialNo,
+		conn:     c,
+		done:     c.mDone,
+		reader: &recvReader{
+			ctx:  ctx,
+			recv: make(chan *codec.FrameResponse, 1),
+		},
+	}
+
+	if prevConn, ok := c.connections[serialNo]; ok {
+		close(prevConn.reader.recv)
+	}
+	c.connections[serialNo] = vc
+	return vc, nil
+}
+
+func (c *Connection) close(lastErr error, done chan struct{}) {
+	if lastErr == nil {
+		return
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.state == Closed {
+		return
+	}
+
+	select {
+	case <-done:
+		return
+	default:
+	}
+
+	c.state = Closing
+	c.err = lastErr
+	c.connections = make(map[uint32]*MultiplexedConnection)
+	close(c.done)
+	if c.conn != nil {
+		c.conn.Close()
+	}
+	err := c.reconnect()
+	if err != nil {
+		c.state = Closed
+		close(c.mDone)
+		c.multiplexed.connections.Delete(c)
+	}
+}
+
+func (c *Connection) reconnect() error {
+	conn, err := dialWithTimeout(c.dialOpts)
+	if err != nil {
+		return err
+	}
+	c.done = make(chan struct{})
+	c.conn = conn
+	c.framer = codec.New(conn)
+	c.buffer.done = c.done
+	c.state = Connected
+	c.err = nil
+	go c.reader()
+	go c.writer()
+	return nil
+}
+
+func (c *Connection) writer() {
+	var lastErr error
+	for {
+		select {
+		case <-c.done:
+			return
+		default:
+		}
+		req, err := c.buffer.get()
+		if err != nil {
+			lastErr = err
+			break
+		}
+		if err := c.write(req); err != nil {
+			lastErr = err
+			break
+		}
+	}
+	c.close(lastErr, c.done)
+}
+
+func (c *Connection) send(b []byte) error {
+	if c.state == Closed {
+		return ErrConnClosed
+	}
+
+	select {
+	case c.buffer.buffer <- b:
+		return nil
+	case <-c.mDone:
+		return c.err
+	}
+}
+
+func (c *Connection) remove(id uint32) {
+	c.mu.Lock()
+	delete(c.connections, id)
+	c.mu.Unlock()
+}
+
+func (c *Connection) write(b []byte) error {
+	sent := 0
+	for sent < len(b) {
+		n, err := c.conn.Write(b[sent:])
+		if err != nil {
+			return err
+		}
+		sent += n
+	}
+	return nil
+}
+
+func (c *Connection) reader() {
+	var lastErr error
+	for {
+		select {
+		case <-c.done:
+			return
+		default:
+		}
+		rsp, err := c.framer.Decode()
+		if err != nil {
+			lastErr = err
+			break
+		}
+		serialNo := rsp.GetSerialNo()
+		c.mu.RLock()
+		mc, ok := c.connections[serialNo]
+		c.mu.RUnlock()
+		if !ok {
+			continue
+		}
+		mc.reader.recv <- rsp
+		mc.conn.remove(rsp.GetSerialNo())
+	}
+	c.close(lastErr, c.done)
+}
+
// Multiplexed is a pool of shared Connections keyed by address.
type Multiplexed struct {
	// connections maps an address string to its *Connection.
	connections *sync.Map
}
+
+func (p *Multiplexed) Get(ctx context.Context, address string, serialNo uint32) (*MultiplexedConnection, error) {
+	select {
+	case <-ctx.Done():
+		return nil, ctx.Err()
+	default:
+	}
+
+	if v, ok := p.connections.Load(address); ok {
+		if c, ok := v.(*Connection); ok {
+			return c.new(ctx, serialNo)
+		}
+		return nil, ErrAssertConnectionFail
+	}
+
+	c := &Connection{
+		address:     address,
+		connections: make(map[uint32]*MultiplexedConnection),
+		done:        make(chan struct{}),
+		mDone:       make(chan struct{}),
+		state:       Initial,
+	}
+	c.buffer = &writerBuffer{
+		buffer: make(chan []byte, queueSize),
+		done:   c.done,
+	}
+	p.connections.Store(address, c)
+
+	conn, dialOpts, err := dial(ctx, address)
+	c.dialOpts = dialOpts
+	if err != nil {
+		return nil, err
+	}
+	c.framer = codec.New(conn)
+	c.conn = conn
+	c.state = Connected
+	go c.reader()
+	go c.writer()
+	return c.new(ctx, serialNo)
+}
+
+func dial(ctx context.Context, address string) (net.Conn, *DialOptions, error) {
+	var timeout time.Duration
+	t, ok := ctx.Deadline()
+	if ok {
+		timeout = t.Sub(time.Now())
+	}
+	dialOpts := &DialOptions{
+		Network: "tcp",
+		Address: address,
+		Timeout: timeout,
+	}
+	select {
+	case <-ctx.Done():
+		return nil, dialOpts, ctx.Err()
+	default:
+	}
+	conn, err := dialWithTimeout(dialOpts)
+	return conn, dialOpts, err
+}
+
+func dialWithTimeout(opts *DialOptions) (net.Conn, error) {
+	if len(opts.CACertFile) == 0 {
+		return net.DialTimeout(opts.Network, opts.Address, opts.Timeout)
+	}
+
+	tlsConf := &tls.Config{}
+	if opts.CACertFile == "none" { // 不需要检验服务证书
+		tlsConf.InsecureSkipVerify = true
+	} else {
+		if len(opts.TLSServerName) == 0 {
+			opts.TLSServerName = opts.Address
+		}
+		tlsConf.ServerName = opts.TLSServerName
+		certPool, err := getCertPool(opts.CACertFile)
+		if err != nil {
+			return nil, err
+		}
+
+		tlsConf.RootCAs = certPool
+
+		if len(opts.TLSCertFile) != 0 {
+			cert, err := tls.LoadX509KeyPair(opts.TLSCertFile, opts.TLSKeyFile)
+			if err != nil {
+				return nil, err
+			}
+			tlsConf.Certificates = []tls.Certificate{cert}
+		}
+	}
+	return tls.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, opts.Network, opts.Address, tlsConf)
+}
+
// getCertPool returns the trusted CA pool named by caCertFile.
//
// "root" returns a nil pool, which makes crypto/tls use the host's root CA
// set. Any other value is read as a PEM file. A file that cannot be read, or
// that contains no PEM certificate, is an error. Previously such a file
// yielded a nil pool and a nil error, which silently fell back to the system
// roots.
func getCertPool(caCertFile string) (*x509.CertPool, error) {
	if caCertFile == "root" {
		return nil, nil
	}
	ca, err := ioutil.ReadFile(caCertFile)
	if err != nil {
		return nil, err
	}
	certPool := x509.NewCertPool()
	if !certPool.AppendCertsFromPEM(ca) {
		return nil, errors.New("pool: no valid PEM certificates found in CA file " + caCertFile)
	}
	return certPool, nil
}
diff --git a/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go b/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go
new file mode 100644
index 0000000..6377032
--- /dev/null
+++ b/tubemq-client-twins/tubemq-client-go/pool/multlplexed_test.go
@@ -0,0 +1,119 @@
+package pool
+
+import (
+	"bytes"
+	"context"
+	"encoding/binary"
+	"io"
+	"log"
+	"net"
+	"strconv"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/apache/incubator-inlong/tubemq-client-twins/tubemq-client-go/codec"
+)
+
+var (
+	address         = "127.0.0.1:0"
+	ch              = make(chan struct{})
+	serialNo uint32 = 1
+)
+
+func init() {
+	go simpleForwardTCPServer(ch)
+	<-ch
+}
+
+func simpleForwardTCPServer(ch chan struct{}) {
+	l, err := net.Listen("tcp", address)
+	if err != nil {
+		log.Fatal(err)
+	}
+	defer l.Close()
+	address = l.Addr().String()
+
+	ch <- struct{}{}
+
+	for {
+		conn, err := l.Accept()
+		if err != nil {
+			log.Fatal(err)
+		}
+
+		go func() {
+			io.Copy(conn, conn)
+		}()
+	}
+}
+
+func Encode(serialNo uint32, body []byte) ([]byte, error) {
+	l := len(body)
+	buf := bytes.NewBuffer(make([]byte, 0, 16+l))
+	if err := binary.Write(buf, binary.BigEndian, codec.RPCProtocolBeginToken); err != nil {
+		return nil, err
+	}
+	if err := binary.Write(buf, binary.BigEndian, serialNo); err != nil {
+		return nil, err
+	}
+	if err := binary.Write(buf, binary.BigEndian, uint32(1)); err != nil {
+		return nil, err
+	}
+	if err := binary.Write(buf, binary.BigEndian, uint32(len(body))); err != nil {
+		return nil, err
+	}
+	buf.Write(body)
+	return buf.Bytes(), nil
+}
+
+func TestBasicMultiplexed(t *testing.T) {
+	serialNo := atomic.AddUint32(&serialNo, 1)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+	defer cancel()
+
+	m := New()
+	mc, err := m.Get(ctx, address, serialNo)
+	body := []byte("hello world")
+
+	buf, err := Encode(serialNo, body)
+	assert.Nil(t, err)
+	assert.Nil(t, mc.Write(buf))
+
+	rsp, err := mc.Read()
+	assert.Nil(t, err)
+	assert.Equal(t, serialNo, rsp.GetSerialNo())
+	assert.Equal(t, body, rsp.GetResponseBuf())
+	assert.Equal(t, mc.Write(nil), nil)
+}
+
+func TestConcurrentMultiplexed(t *testing.T) {
+	count := 1000
+	m := New()
+	wg := sync.WaitGroup{}
+	wg.Add(count)
+	for i := 0; i < count; i++ {
+		go func(i int) {
+			defer wg.Done()
+			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+			defer cancel()
+			serialNo := atomic.AddUint32(&serialNo, 1)
+			mc, err := m.Get(ctx, address, serialNo)
+			assert.Nil(t, err)
+
+			body := []byte("hello world" + strconv.Itoa(i))
+			buf, err := Encode(serialNo, body)
+			assert.Nil(t, err)
+			assert.Nil(t, mc.Write(buf))
+
+			rsp, err := mc.Read()
+			assert.Nil(t, err)
+			assert.Equal(t, serialNo, rsp.GetSerialNo())
+			assert.Equal(t, body, rsp.GetResponseBuf())
+		}(i)
+	}
+	wg.Wait()
+}