You are viewing a plain-text version of this content. The canonical link to it (a hyperlink on the original HTML page) was not preserved in this copy.
Posted to commits@pulsar.apache.org by mm...@apache.org on 2019/10/28 22:10:04 UTC

[pulsar-client-go] branch master updated: [ISSUE #72] Fix data race conditions. (#77)

This is an automated email from the ASF dual-hosted git repository.

mmerli pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new bd30a32  [ISSUE #72] Fix data race conditions. (#77)
bd30a32 is described below

commit bd30a324bb30bb7a049cb44576f392cc2ed01575
Author: cckellogg <cc...@gmail.com>
AuthorDate: Mon Oct 28 15:09:59 2019 -0700

    [ISSUE #72] Fix data race conditions. (#77)
    
    * [ISSUE #72] Fix data race conditions.
    
    * Remove commented out code.
    
    * Revert write request ch to unbuffered.
---
 license_test.go                                    |   2 +-
 pulsar/impl_consumer.go                            |  10 +-
 pulsar/impl_partition_producer.go                  |  19 ++-
 pulsar/internal/connection.go                      | 150 +++++++++++++--------
 ...unackedMsgTracker.go => unacked_msg_tracker.go} |   6 +-
 ...Tracker_test.go => unacked_msg_tracker_test.go} |   3 +-
 6 files changed, 119 insertions(+), 71 deletions(-)

diff --git a/license_test.go b/license_test.go
index 51b7f44..84998c5 100644
--- a/license_test.go
+++ b/license_test.go
@@ -65,7 +65,7 @@ var otherCheck = regexp.MustCompile(`#
 `)
 
 var skip = map[string]bool{
-	"pkg/pb/PulsarApi.pb.go":true,
+	"pkg/pb/PulsarApi.pb.go": true,
 }
 
 func TestLicense(t *testing.T) {
diff --git a/pulsar/impl_consumer.go b/pulsar/impl_consumer.go
index 0178943..6a196ad 100644
--- a/pulsar/impl_consumer.go
+++ b/pulsar/impl_consumer.go
@@ -98,10 +98,12 @@ func singleTopicSubscribe(client *client, options *ConsumerOptions, topic string
 	ch := make(chan ConsumerError, numPartitions)
 
 	for partitionIdx, partitionTopic := range partitions {
+		// this needs to be created outside in the same go routine since
+		// newPartitionConsumer can modify the shared options struct causing a race condition
+		cons, err := newPartitionConsumer(client, partitionTopic, options, partitionIdx, numPartitions, c.queue)
 		go func(partitionIdx int, partitionTopic string) {
-			cons, e := newPartitionConsumer(client, partitionTopic, options, partitionIdx, numPartitions, c.queue)
 			ch <- ConsumerError{
-				err:       e,
+				err:       err,
 				partition: partitionIdx,
 				cons:      cons,
 			}
@@ -141,8 +143,8 @@ func (c *consumer) Subscription() string {
 
 func (c *consumer) Unsubscribe() error {
 	var errMsg string
-	for _, c := range c.consumers {
-		if err := c.Unsubscribe(); err != nil {
+	for _, consumer := range c.consumers {
+		if err := consumer.Unsubscribe(); err != nil {
 			errMsg += fmt.Sprintf("topic %s, subscription %s: %s", c.Topic(), c.Subscription(), err)
 		}
 	}
diff --git a/pulsar/impl_partition_producer.go b/pulsar/impl_partition_producer.go
index f09cd42..196dbbc 100644
--- a/pulsar/impl_partition_producer.go
+++ b/pulsar/impl_partition_producer.go
@@ -25,11 +25,11 @@ import (
 
 	"github.com/golang/protobuf/proto"
 
+	log "github.com/sirupsen/logrus"
+
 	"github.com/apache/pulsar-client-go/pkg/pb"
 	"github.com/apache/pulsar-client-go/pulsar/internal"
 	"github.com/apache/pulsar-client-go/util"
-
-	log "github.com/sirupsen/logrus"
 )
 
 type producerState int
@@ -272,6 +272,7 @@ func (p *partitionProducer) internalSend(request *sendRequest) {
 }
 
 type pendingItem struct {
+	sync.Mutex
 	batchData    []byte
 	sequenceID   uint64
 	sendRequests []interface{}
@@ -300,13 +301,19 @@ func (p *partitionProducer) internalFlush(fr *flushRequest) {
 		return
 	}
 
-	pi.sendRequests = append(pi.sendRequests, &sendRequest{
+	sendReq := &sendRequest{
 		msg: nil,
 		callback: func(id MessageID, message *ProducerMessage, e error) {
 			fr.err = e
 			fr.waitGroup.Done()
 		},
-	})
+	}
+
+	// lock the pending request while adding requests
+	// since the ReceivedSendReceipt func iterates over this list
+	pi.Lock()
+	pi.sendRequests = append(pi.sendRequests, sendReq)
+	pi.Unlock()
 }
 
 func (p *partitionProducer) Send(ctx context.Context, msg *ProducerMessage) error {
@@ -370,6 +377,10 @@ func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt)
 
 	// The ack was indeed for the expected item in the queue, we can remove it and trigger the callback
 	p.pendingQueue.Poll()
+
+	// lock the pending item while sending the requests
+	pi.Lock()
+	defer pi.Unlock()
 	for idx, i := range pi.sendRequests {
 		sr := i.(*sendRequest)
 		if sr.msg != nil {
diff --git a/pulsar/internal/connection.go b/pulsar/internal/connection.go
index fe43f79..592beca 100644
--- a/pulsar/internal/connection.go
+++ b/pulsar/internal/connection.go
@@ -81,6 +81,23 @@ const (
 	connectionClosed
 )
 
+func (s connectionState) String() string {
+	switch s {
+	case connectionInit:
+		return "Initializing"
+	case connectionConnecting:
+		return "Connecting"
+	case connectionTCPConnected:
+		return "TCPConnected"
+	case connectionReady:
+		return "Ready"
+	case connectionClosed:
+		return "Closed"
+	default:
+		return "Unknown"
+	}
+}
+
 const keepAliveInterval = 30 * time.Second
 
 type request struct {
@@ -98,8 +115,11 @@ type connection struct {
 	physicalAddr *url.URL
 	cnx          net.Conn
 
+	writeBufferLock sync.Mutex
 	writeBuffer          Buffer
 	reader               *connectionReader
+
+	lastDataReceivedLock sync.Mutex
 	lastDataReceivedTime time.Time
 	pingTicker           *time.Ticker
 
@@ -107,13 +127,15 @@ type connection struct {
 
 	requestIDGenerator uint64
 
-	incomingRequests chan *request
-	writeRequests    chan []byte
+	incomingRequestsCh chan *request
+	writeRequestsCh    chan []byte
 
 	mapMutex    sync.RWMutex
 	pendingReqs map[uint64]*request
 	listeners   map[uint64]ConnectionListener
-	connWrapper *ConnWrapper
+
+	consumerHandlersLock sync.RWMutex
+	consumerHandlers map[uint64]ConsumerHandler
 
 	tlsOptions *TLSOptions
 	auth       auth.Provider
@@ -125,17 +147,17 @@ func newConnection(logicalAddr *url.URL, physicalAddr *url.URL, tlsOptions *TLSO
 		logicalAddr:          logicalAddr,
 		physicalAddr:         physicalAddr,
 		writeBuffer:          NewBuffer(4096),
-		log:                  log.WithField("raddr", physicalAddr),
+		log:                  log.WithField("remote_addr", physicalAddr),
 		pendingReqs:          make(map[uint64]*request),
 		lastDataReceivedTime: time.Now(),
 		pingTicker:           time.NewTicker(keepAliveInterval),
 		tlsOptions:           tlsOptions,
 		auth:                 auth,
 
-		incomingRequests: make(chan *request),
-		writeRequests:    make(chan []byte),
-		listeners:        make(map[uint64]ConnectionListener),
-		connWrapper:      NewConnWrapper(),
+		incomingRequestsCh: make(chan *request),
+		writeRequestsCh:    make(chan []byte),
+		listeners:          make(map[uint64]ConnectionListener),
+		consumerHandlers:   make(map[uint64]ConsumerHandler),
 	}
 	cnx.reader = newConnectionReader(cnx)
 	cnx.cond = sync.NewCond(cnx)
@@ -157,7 +179,7 @@ func (c *connection) start() {
 	}()
 }
 
-func (c *connection) connect() (ok bool) {
+func (c *connection) connect() bool {
 	c.log.Info("Connecting to broker")
 
 	var (
@@ -185,15 +207,19 @@ func (c *connection) connect() (ok bool) {
 		c.Close()
 		return false
 	}
+
+	c.Lock()
 	c.cnx = cnx
-	c.log = c.log.WithField("laddr", c.cnx.LocalAddr())
-	c.log.Debug("TCP connection established")
-	c.state = connectionTCPConnected
+	c.log = c.log.WithField("local_addr", c.cnx.LocalAddr())
+	c.log.Info("TCP connection established")
+	c.Unlock()
+
+	c.changeState(connectionTCPConnected)
 
 	return true
 }
 
-func (c *connection) doHandshake() (ok bool) {
+func (c *connection) doHandshake() bool {
 	// Send 'Connect' command to initiate handshake
 	version := int32(pb.ProtocolVersion_v13)
 
@@ -231,24 +257,16 @@ func (c *connection) waitUntilReady() error {
 	c.Lock()
 	defer c.Unlock()
 
-	for {
+	for c.state != connectionReady {
 		c.log.Debug("Wait until connection is ready. State: ", c.state)
-		switch c.state {
-		case connectionInit:
-			fallthrough
-		case connectionConnecting:
-			fallthrough
-		case connectionTCPConnected:
-			// Wait for the state to change
-			c.cond.Wait()
-
-		case connectionReady:
-			return nil
-
-		case connectionClosed:
+		if c.state == connectionClosed {
 			return errors.New("connection error")
 		}
+		// wait for a new connection state change
+		c.cond.Wait()
 	}
+
+	return nil
 }
 
 func (c *connection) run() {
@@ -257,7 +275,7 @@ func (c *connection) run() {
 
 	for {
 		select {
-		case req := <-c.incomingRequests:
+		case req := <-c.incomingRequestsCh:
 			if req == nil {
 				return
 			}
@@ -266,7 +284,7 @@ func (c *connection) run() {
 			c.mapMutex.Unlock()
 			c.writeCommand(req.cmd)
 
-		case data := <-c.writeRequests:
+		case data := <-c.writeRequestsCh:
 			if data == nil {
 				return
 			}
@@ -279,7 +297,7 @@ func (c *connection) run() {
 }
 
 func (c *connection) WriteData(data []byte) {
-	c.writeRequests <- data
+	c.writeRequestsCh <- data
 }
 
 func (c *connection) internalWriteData(data []byte) {
@@ -296,6 +314,9 @@ func (c *connection) writeCommand(cmd proto.Message) {
 	cmdSize := uint32(proto.Size(cmd))
 	frameSize := cmdSize + 4
 
+	c.writeBufferLock.Lock()
+	defer c.writeBufferLock.Unlock()
+
 	c.writeBuffer.Clear()
 	c.writeBuffer.WriteUint32(frameSize)
 	c.writeBuffer.WriteUint32(cmdSize)
@@ -305,12 +326,13 @@ func (c *connection) writeCommand(cmd proto.Message) {
 	}
 
 	c.writeBuffer.Write(serialized)
-	c.internalWriteData(c.writeBuffer.ReadableSlice())
+	data := c.writeBuffer.ReadableSlice()
+	c.internalWriteData(data)
 }
 
 func (c *connection) receivedCommand(cmd *pb.BaseCommand, headersAndPayload []byte) {
 	c.log.Debugf("Received command: %s -- payload: %v", cmd, headersAndPayload)
-	c.lastDataReceivedTime = time.Now()
+	c.setLastDataReceived(time.Now())
 	var err error
 
 	switch *cmd.Type {
@@ -374,11 +396,11 @@ func (c *connection) receivedCommand(cmd *pb.BaseCommand, headersAndPayload []by
 }
 
 func (c *connection) Write(data []byte) {
-	c.writeRequests <- data
+	c.writeRequestsCh <- data
 }
 
 func (c *connection) SendRequest(requestID uint64, req *pb.BaseCommand, callback func(command *pb.BaseCommand)) {
-	c.incomingRequests <- &request{
+	c.incomingRequestsCh <- &request{
 		id:       requestID,
 		cmd:      req,
 		callback: callback,
@@ -406,7 +428,6 @@ func (c *connection) handleResponse(requestID uint64, response *pb.BaseCommand)
 }
 
 func (c *connection) handleSendReceipt(response *pb.CommandSendReceipt) {
-	c.log.Debug("Got SEND_RECEIPT: ", response)
 	producerID := response.GetProducerId()
 	if producer, ok := c.listeners[producerID]; ok {
 		producer.ReceivedSendReceipt(response)
@@ -418,7 +439,7 @@ func (c *connection) handleSendReceipt(response *pb.CommandSendReceipt) {
 func (c *connection) handleMessage(response *pb.CommandMessage, payload []byte) error {
 	c.log.Debug("Got Message: ", response)
 	consumerID := response.GetConsumerId()
-	if consumer, ok := c.connWrapper.Consumers[consumerID]; ok {
+	if consumer, ok := c.consumerHandler(consumerID); ok {
 		err := consumer.MessageReceived(response, payload)
 		if err != nil {
 			c.log.WithField("consumerID", consumerID).Error("handle message err: ", response.MessageId)
@@ -430,8 +451,21 @@ func (c *connection) handleMessage(response *pb.CommandMessage, payload []byte)
 	return nil
 }
 
+func (c *connection) lastDataReceived() time.Time {
+	c.lastDataReceivedLock.Lock()
+	defer c.lastDataReceivedLock.Unlock()
+	t := c.lastDataReceivedTime
+	return t;
+}
+
+func (c *connection) setLastDataReceived(t time.Time) {
+	c.lastDataReceivedLock.Lock()
+	defer c.lastDataReceivedLock.Unlock()
+	c.lastDataReceivedTime = t
+}
+
 func (c *connection) sendPing() {
-	if c.lastDataReceivedTime.Add(2 * keepAliveInterval).Before(time.Now()) {
+	if c.lastDataReceived().Add(2 * keepAliveInterval).Before(time.Now()) {
 		// We have not received a response to the previous Ping request, the
 		// connection to broker is stale
 		c.log.Info("Detected stale connection to broker")
@@ -454,7 +488,7 @@ func (c *connection) handlePing() {
 func (c *connection) handleCloseConsumer(closeConsumer *pb.CommandCloseConsumer) {
 	c.log.Infof("Broker notification of Closed consumer: %d", closeConsumer.GetConsumerId())
 	consumerID := closeConsumer.GetConsumerId()
-	if consumer, ok := c.connWrapper.Consumers[consumerID]; ok {
+	if consumer, ok := c.consumerHandler(consumerID); ok {
 		if !util.IsNil(consumer) {
 			consumer.ConnectionClosed()
 		}
@@ -503,15 +537,17 @@ func (c *connection) Close() {
 		c.cnx.Close()
 	}
 	c.pingTicker.Stop()
-	close(c.incomingRequests)
-	close(c.writeRequests)
+	close(c.incomingRequestsCh)
+	close(c.writeRequestsCh)
 
 	for _, listener := range c.listeners {
 		listener.ConnectionClosed()
 	}
 
-	for _, cnx := range c.connWrapper.Consumers {
-		cnx.ConnectionClosed()
+	c.consumerHandlersLock.RLock()
+	defer c.consumerHandlersLock.RUnlock()
+	for _, handler := range c.consumerHandlers {
+		handler.ConnectionClosed()
 	}
 }
 
@@ -560,25 +596,21 @@ func (c *connection) getTLSConfig() (*tls.Config, error) {
 	return tlsConfig, nil
 }
 
-type ConnWrapper struct {
-	Rwmu      sync.RWMutex
-	Consumers map[uint64]ConsumerHandler
-}
-
-func NewConnWrapper() *ConnWrapper {
-	return &ConnWrapper{
-		Consumers: make(map[uint64]ConsumerHandler),
-	}
-}
-
 func (c *connection) AddConsumeHandler(id uint64, handler ConsumerHandler) {
-	c.connWrapper.Rwmu.Lock()
-	c.connWrapper.Consumers[id] = handler
-	c.connWrapper.Rwmu.Unlock()
+	c.consumerHandlersLock.Lock()
+	defer c.consumerHandlersLock.Unlock()
+	c.consumerHandlers[id] = handler
 }
 
 func (c *connection) DeleteConsumeHandler(id uint64) {
-	c.connWrapper.Rwmu.Lock()
-	delete(c.connWrapper.Consumers, id)
-	c.connWrapper.Rwmu.Unlock()
+	c.consumerHandlersLock.Lock()
+	defer c.consumerHandlersLock.Unlock()
+	delete(c.consumerHandlers, id)
+}
+
+func (c *connection) consumerHandler(id uint64) (ConsumerHandler, bool) {
+	c.consumerHandlersLock.RLock()
+	defer c.consumerHandlersLock.RUnlock()
+	h, ok := c.consumerHandlers[id]
+	return h, ok
 }
diff --git a/pulsar/unackedMsgTracker.go b/pulsar/unacked_msg_tracker.go
similarity index 99%
rename from pulsar/unackedMsgTracker.go
rename to pulsar/unacked_msg_tracker.go
index 09ff0cb..ffc6eff 100644
--- a/pulsar/unackedMsgTracker.go
+++ b/pulsar/unacked_msg_tracker.go
@@ -21,11 +21,12 @@ import (
 	"sync"
 	"time"
 
-	"github.com/apache/pulsar-client-go/pkg/pb"
 	"github.com/golang/protobuf/proto"
 
 	set "github.com/deckarep/golang-set"
 	log "github.com/sirupsen/logrus"
+
+	"github.com/apache/pulsar-client-go/pkg/pb"
 )
 
 type UnackedMessageTracker struct {
@@ -146,6 +147,7 @@ func (t *UnackedMessageTracker) handlerCmd() {
 		select {
 		case tick := <-t.timeout.C:
 			if t.isAckTimeout() {
+				t.cmu.Lock()
 				log.Debugf(" %d messages have timed-out", t.oldOpenSet.Cardinality())
 				messageIds := make([]*pb.MessageIdData, 0)
 
@@ -153,10 +155,10 @@ func (t *UnackedMessageTracker) handlerCmd() {
 					messageIds = append(messageIds, i.(*pb.MessageIdData))
 					return false
 				})
-
 				log.Debugf("messageID length is:%d", len(messageIds))
 
 				t.oldOpenSet.Clear()
+				t.cmu.Unlock()
 
 				if t.pcs != nil {
 					messageIdsMap := make(map[int32][]*pb.MessageIdData)
diff --git a/pulsar/unackMsgTracker_test.go b/pulsar/unacked_msg_tracker_test.go
similarity index 99%
rename from pulsar/unackMsgTracker_test.go
rename to pulsar/unacked_msg_tracker_test.go
index edf7ddc..3848ce9 100644
--- a/pulsar/unackMsgTracker_test.go
+++ b/pulsar/unacked_msg_tracker_test.go
@@ -20,9 +20,10 @@ package pulsar
 import (
 	"testing"
 
-	"github.com/apache/pulsar-client-go/pkg/pb"
 	"github.com/golang/protobuf/proto"
 	"github.com/stretchr/testify/assert"
+
+	"github.com/apache/pulsar-client-go/pkg/pb"
 )
 
 func TestUnackedMessageTracker(t *testing.T) {