You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pulsar.apache.org by mm...@apache.org on 2019/10/28 22:10:04 UTC
[pulsar-client-go] branch master updated: [ISSUE #72] Fix data race conditions. (#77)
This is an automated email from the ASF dual-hosted git repository.
mmerli pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git
The following commit(s) were added to refs/heads/master by this push:
new bd30a32 [ISSUE #72] Fix data race conditions. (#77)
bd30a32 is described below
commit bd30a324bb30bb7a049cb44576f392cc2ed01575
Author: cckellogg <cc...@gmail.com>
AuthorDate: Mon Oct 28 15:09:59 2019 -0700
[ISSUE #72] Fix data race conditions. (#77)
* [ISSUE #72] Fix data race conditions.
* Remove commented out code.
* Revert the write request channel to unbuffered.
---
license_test.go | 2 +-
pulsar/impl_consumer.go | 10 +-
pulsar/impl_partition_producer.go | 19 ++-
pulsar/internal/connection.go | 150 +++++++++++++--------
...unackedMsgTracker.go => unacked_msg_tracker.go} | 6 +-
...Tracker_test.go => unacked_msg_tracker_test.go} | 3 +-
6 files changed, 119 insertions(+), 71 deletions(-)
diff --git a/license_test.go b/license_test.go
index 51b7f44..84998c5 100644
--- a/license_test.go
+++ b/license_test.go
@@ -65,7 +65,7 @@ var otherCheck = regexp.MustCompile(`#
`)
var skip = map[string]bool{
- "pkg/pb/PulsarApi.pb.go":true,
+ "pkg/pb/PulsarApi.pb.go": true,
}
func TestLicense(t *testing.T) {
diff --git a/pulsar/impl_consumer.go b/pulsar/impl_consumer.go
index 0178943..6a196ad 100644
--- a/pulsar/impl_consumer.go
+++ b/pulsar/impl_consumer.go
@@ -98,10 +98,12 @@ func singleTopicSubscribe(client *client, options *ConsumerOptions, topic string
ch := make(chan ConsumerError, numPartitions)
for partitionIdx, partitionTopic := range partitions {
+ // this needs to be created outside in the same go routine since
+ // newPartitionConsumer can modify the shared options struct causing a race condition
+ cons, err := newPartitionConsumer(client, partitionTopic, options, partitionIdx, numPartitions, c.queue)
go func(partitionIdx int, partitionTopic string) {
- cons, e := newPartitionConsumer(client, partitionTopic, options, partitionIdx, numPartitions, c.queue)
ch <- ConsumerError{
- err: e,
+ err: err,
partition: partitionIdx,
cons: cons,
}
@@ -141,8 +143,8 @@ func (c *consumer) Subscription() string {
func (c *consumer) Unsubscribe() error {
var errMsg string
- for _, c := range c.consumers {
- if err := c.Unsubscribe(); err != nil {
+ for _, consumer := range c.consumers {
+ if err := consumer.Unsubscribe(); err != nil {
errMsg += fmt.Sprintf("topic %s, subscription %s: %s", c.Topic(), c.Subscription(), err)
}
}
diff --git a/pulsar/impl_partition_producer.go b/pulsar/impl_partition_producer.go
index f09cd42..196dbbc 100644
--- a/pulsar/impl_partition_producer.go
+++ b/pulsar/impl_partition_producer.go
@@ -25,11 +25,11 @@ import (
"github.com/golang/protobuf/proto"
+ log "github.com/sirupsen/logrus"
+
"github.com/apache/pulsar-client-go/pkg/pb"
"github.com/apache/pulsar-client-go/pulsar/internal"
"github.com/apache/pulsar-client-go/util"
-
- log "github.com/sirupsen/logrus"
)
type producerState int
@@ -272,6 +272,7 @@ func (p *partitionProducer) internalSend(request *sendRequest) {
}
type pendingItem struct {
+ sync.Mutex
batchData []byte
sequenceID uint64
sendRequests []interface{}
@@ -300,13 +301,19 @@ func (p *partitionProducer) internalFlush(fr *flushRequest) {
return
}
- pi.sendRequests = append(pi.sendRequests, &sendRequest{
+ sendReq := &sendRequest{
msg: nil,
callback: func(id MessageID, message *ProducerMessage, e error) {
fr.err = e
fr.waitGroup.Done()
},
- })
+ }
+
+ // lock the pending request while adding requests
+ // since the ReceivedSendReceipt func iterates over this list
+ pi.Lock()
+ pi.sendRequests = append(pi.sendRequests, sendReq)
+ pi.Unlock()
}
func (p *partitionProducer) Send(ctx context.Context, msg *ProducerMessage) error {
@@ -370,6 +377,10 @@ func (p *partitionProducer) ReceivedSendReceipt(response *pb.CommandSendReceipt)
// The ack was indeed for the expected item in the queue, we can remove it and trigger the callback
p.pendingQueue.Poll()
+
+ // lock the pending item while sending the requests
+ pi.Lock()
+ defer pi.Unlock()
for idx, i := range pi.sendRequests {
sr := i.(*sendRequest)
if sr.msg != nil {
diff --git a/pulsar/internal/connection.go b/pulsar/internal/connection.go
index fe43f79..592beca 100644
--- a/pulsar/internal/connection.go
+++ b/pulsar/internal/connection.go
@@ -81,6 +81,23 @@ const (
connectionClosed
)
+func (s connectionState) String() string {
+ switch s {
+ case connectionInit:
+ return "Initializing"
+ case connectionConnecting:
+ return "Connecting"
+ case connectionTCPConnected:
+ return "TCPConnected"
+ case connectionReady:
+ return "Ready"
+ case connectionClosed:
+ return "Closed"
+ default:
+ return "Unknown"
+ }
+}
+
const keepAliveInterval = 30 * time.Second
type request struct {
@@ -98,8 +115,11 @@ type connection struct {
physicalAddr *url.URL
cnx net.Conn
+ writeBufferLock sync.Mutex
writeBuffer Buffer
reader *connectionReader
+
+ lastDataReceivedLock sync.Mutex
lastDataReceivedTime time.Time
pingTicker *time.Ticker
@@ -107,13 +127,15 @@ type connection struct {
requestIDGenerator uint64
- incomingRequests chan *request
- writeRequests chan []byte
+ incomingRequestsCh chan *request
+ writeRequestsCh chan []byte
mapMutex sync.RWMutex
pendingReqs map[uint64]*request
listeners map[uint64]ConnectionListener
- connWrapper *ConnWrapper
+
+ consumerHandlersLock sync.RWMutex
+ consumerHandlers map[uint64]ConsumerHandler
tlsOptions *TLSOptions
auth auth.Provider
@@ -125,17 +147,17 @@ func newConnection(logicalAddr *url.URL, physicalAddr *url.URL, tlsOptions *TLSO
logicalAddr: logicalAddr,
physicalAddr: physicalAddr,
writeBuffer: NewBuffer(4096),
- log: log.WithField("raddr", physicalAddr),
+ log: log.WithField("remote_addr", physicalAddr),
pendingReqs: make(map[uint64]*request),
lastDataReceivedTime: time.Now(),
pingTicker: time.NewTicker(keepAliveInterval),
tlsOptions: tlsOptions,
auth: auth,
- incomingRequests: make(chan *request),
- writeRequests: make(chan []byte),
- listeners: make(map[uint64]ConnectionListener),
- connWrapper: NewConnWrapper(),
+ incomingRequestsCh: make(chan *request),
+ writeRequestsCh: make(chan []byte),
+ listeners: make(map[uint64]ConnectionListener),
+ consumerHandlers: make(map[uint64]ConsumerHandler),
}
cnx.reader = newConnectionReader(cnx)
cnx.cond = sync.NewCond(cnx)
@@ -157,7 +179,7 @@ func (c *connection) start() {
}()
}
-func (c *connection) connect() (ok bool) {
+func (c *connection) connect() bool {
c.log.Info("Connecting to broker")
var (
@@ -185,15 +207,19 @@ func (c *connection) connect() (ok bool) {
c.Close()
return false
}
+
+ c.Lock()
c.cnx = cnx
- c.log = c.log.WithField("laddr", c.cnx.LocalAddr())
- c.log.Debug("TCP connection established")
- c.state = connectionTCPConnected
+ c.log = c.log.WithField("local_addr", c.cnx.LocalAddr())
+ c.log.Info("TCP connection established")
+ c.Unlock()
+
+ c.changeState(connectionTCPConnected)
return true
}
-func (c *connection) doHandshake() (ok bool) {
+func (c *connection) doHandshake() bool {
// Send 'Connect' command to initiate handshake
version := int32(pb.ProtocolVersion_v13)
@@ -231,24 +257,16 @@ func (c *connection) waitUntilReady() error {
c.Lock()
defer c.Unlock()
- for {
+ for c.state != connectionReady {
c.log.Debug("Wait until connection is ready. State: ", c.state)
- switch c.state {
- case connectionInit:
- fallthrough
- case connectionConnecting:
- fallthrough
- case connectionTCPConnected:
- // Wait for the state to change
- c.cond.Wait()
-
- case connectionReady:
- return nil
-
- case connectionClosed:
+ if c.state == connectionClosed {
return errors.New("connection error")
}
+ // wait for a new connection state change
+ c.cond.Wait()
}
+
+ return nil
}
func (c *connection) run() {
@@ -257,7 +275,7 @@ func (c *connection) run() {
for {
select {
- case req := <-c.incomingRequests:
+ case req := <-c.incomingRequestsCh:
if req == nil {
return
}
@@ -266,7 +284,7 @@ func (c *connection) run() {
c.mapMutex.Unlock()
c.writeCommand(req.cmd)
- case data := <-c.writeRequests:
+ case data := <-c.writeRequestsCh:
if data == nil {
return
}
@@ -279,7 +297,7 @@ func (c *connection) run() {
}
func (c *connection) WriteData(data []byte) {
- c.writeRequests <- data
+ c.writeRequestsCh <- data
}
func (c *connection) internalWriteData(data []byte) {
@@ -296,6 +314,9 @@ func (c *connection) writeCommand(cmd proto.Message) {
cmdSize := uint32(proto.Size(cmd))
frameSize := cmdSize + 4
+ c.writeBufferLock.Lock()
+ defer c.writeBufferLock.Unlock()
+
c.writeBuffer.Clear()
c.writeBuffer.WriteUint32(frameSize)
c.writeBuffer.WriteUint32(cmdSize)
@@ -305,12 +326,13 @@ func (c *connection) writeCommand(cmd proto.Message) {
}
c.writeBuffer.Write(serialized)
- c.internalWriteData(c.writeBuffer.ReadableSlice())
+ data := c.writeBuffer.ReadableSlice()
+ c.internalWriteData(data)
}
func (c *connection) receivedCommand(cmd *pb.BaseCommand, headersAndPayload []byte) {
c.log.Debugf("Received command: %s -- payload: %v", cmd, headersAndPayload)
- c.lastDataReceivedTime = time.Now()
+ c.setLastDataReceived(time.Now())
var err error
switch *cmd.Type {
@@ -374,11 +396,11 @@ func (c *connection) receivedCommand(cmd *pb.BaseCommand, headersAndPayload []by
}
func (c *connection) Write(data []byte) {
- c.writeRequests <- data
+ c.writeRequestsCh <- data
}
func (c *connection) SendRequest(requestID uint64, req *pb.BaseCommand, callback func(command *pb.BaseCommand)) {
- c.incomingRequests <- &request{
+ c.incomingRequestsCh <- &request{
id: requestID,
cmd: req,
callback: callback,
@@ -406,7 +428,6 @@ func (c *connection) handleResponse(requestID uint64, response *pb.BaseCommand)
}
func (c *connection) handleSendReceipt(response *pb.CommandSendReceipt) {
- c.log.Debug("Got SEND_RECEIPT: ", response)
producerID := response.GetProducerId()
if producer, ok := c.listeners[producerID]; ok {
producer.ReceivedSendReceipt(response)
@@ -418,7 +439,7 @@ func (c *connection) handleSendReceipt(response *pb.CommandSendReceipt) {
func (c *connection) handleMessage(response *pb.CommandMessage, payload []byte) error {
c.log.Debug("Got Message: ", response)
consumerID := response.GetConsumerId()
- if consumer, ok := c.connWrapper.Consumers[consumerID]; ok {
+ if consumer, ok := c.consumerHandler(consumerID); ok {
err := consumer.MessageReceived(response, payload)
if err != nil {
c.log.WithField("consumerID", consumerID).Error("handle message err: ", response.MessageId)
@@ -430,8 +451,21 @@ func (c *connection) handleMessage(response *pb.CommandMessage, payload []byte)
return nil
}
+func (c *connection) lastDataReceived() time.Time {
+ c.lastDataReceivedLock.Lock()
+ defer c.lastDataReceivedLock.Unlock()
+ t := c.lastDataReceivedTime
+ return t;
+}
+
+func (c *connection) setLastDataReceived(t time.Time) {
+ c.lastDataReceivedLock.Lock()
+ defer c.lastDataReceivedLock.Unlock()
+ c.lastDataReceivedTime = t
+}
+
func (c *connection) sendPing() {
- if c.lastDataReceivedTime.Add(2 * keepAliveInterval).Before(time.Now()) {
+ if c.lastDataReceived().Add(2 * keepAliveInterval).Before(time.Now()) {
// We have not received a response to the previous Ping request, the
// connection to broker is stale
c.log.Info("Detected stale connection to broker")
@@ -454,7 +488,7 @@ func (c *connection) handlePing() {
func (c *connection) handleCloseConsumer(closeConsumer *pb.CommandCloseConsumer) {
c.log.Infof("Broker notification of Closed consumer: %d", closeConsumer.GetConsumerId())
consumerID := closeConsumer.GetConsumerId()
- if consumer, ok := c.connWrapper.Consumers[consumerID]; ok {
+ if consumer, ok := c.consumerHandler(consumerID); ok {
if !util.IsNil(consumer) {
consumer.ConnectionClosed()
}
@@ -503,15 +537,17 @@ func (c *connection) Close() {
c.cnx.Close()
}
c.pingTicker.Stop()
- close(c.incomingRequests)
- close(c.writeRequests)
+ close(c.incomingRequestsCh)
+ close(c.writeRequestsCh)
for _, listener := range c.listeners {
listener.ConnectionClosed()
}
- for _, cnx := range c.connWrapper.Consumers {
- cnx.ConnectionClosed()
+ c.consumerHandlersLock.RLock()
+ defer c.consumerHandlersLock.RUnlock()
+ for _, handler := range c.consumerHandlers {
+ handler.ConnectionClosed()
}
}
@@ -560,25 +596,21 @@ func (c *connection) getTLSConfig() (*tls.Config, error) {
return tlsConfig, nil
}
-type ConnWrapper struct {
- Rwmu sync.RWMutex
- Consumers map[uint64]ConsumerHandler
-}
-
-func NewConnWrapper() *ConnWrapper {
- return &ConnWrapper{
- Consumers: make(map[uint64]ConsumerHandler),
- }
-}
-
func (c *connection) AddConsumeHandler(id uint64, handler ConsumerHandler) {
- c.connWrapper.Rwmu.Lock()
- c.connWrapper.Consumers[id] = handler
- c.connWrapper.Rwmu.Unlock()
+ c.consumerHandlersLock.Lock()
+ defer c.consumerHandlersLock.Unlock()
+ c.consumerHandlers[id] = handler
}
func (c *connection) DeleteConsumeHandler(id uint64) {
- c.connWrapper.Rwmu.Lock()
- delete(c.connWrapper.Consumers, id)
- c.connWrapper.Rwmu.Unlock()
+ c.consumerHandlersLock.Lock()
+ defer c.consumerHandlersLock.Unlock()
+ delete(c.consumerHandlers, id)
+}
+
+func (c *connection) consumerHandler(id uint64) (ConsumerHandler, bool) {
+ c.consumerHandlersLock.RLock()
+ defer c.consumerHandlersLock.RUnlock()
+ h, ok := c.consumerHandlers[id]
+ return h, ok
}
diff --git a/pulsar/unackedMsgTracker.go b/pulsar/unacked_msg_tracker.go
similarity index 99%
rename from pulsar/unackedMsgTracker.go
rename to pulsar/unacked_msg_tracker.go
index 09ff0cb..ffc6eff 100644
--- a/pulsar/unackedMsgTracker.go
+++ b/pulsar/unacked_msg_tracker.go
@@ -21,11 +21,12 @@ import (
"sync"
"time"
- "github.com/apache/pulsar-client-go/pkg/pb"
"github.com/golang/protobuf/proto"
set "github.com/deckarep/golang-set"
log "github.com/sirupsen/logrus"
+
+ "github.com/apache/pulsar-client-go/pkg/pb"
)
type UnackedMessageTracker struct {
@@ -146,6 +147,7 @@ func (t *UnackedMessageTracker) handlerCmd() {
select {
case tick := <-t.timeout.C:
if t.isAckTimeout() {
+ t.cmu.Lock()
log.Debugf(" %d messages have timed-out", t.oldOpenSet.Cardinality())
messageIds := make([]*pb.MessageIdData, 0)
@@ -153,10 +155,10 @@ func (t *UnackedMessageTracker) handlerCmd() {
messageIds = append(messageIds, i.(*pb.MessageIdData))
return false
})
-
log.Debugf("messageID length is:%d", len(messageIds))
t.oldOpenSet.Clear()
+ t.cmu.Unlock()
if t.pcs != nil {
messageIdsMap := make(map[int32][]*pb.MessageIdData)
diff --git a/pulsar/unackMsgTracker_test.go b/pulsar/unacked_msg_tracker_test.go
similarity index 99%
rename from pulsar/unackMsgTracker_test.go
rename to pulsar/unacked_msg_tracker_test.go
index edf7ddc..3848ce9 100644
--- a/pulsar/unackMsgTracker_test.go
+++ b/pulsar/unacked_msg_tracker_test.go
@@ -20,9 +20,10 @@ package pulsar
import (
"testing"
- "github.com/apache/pulsar-client-go/pkg/pb"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
+
+ "github.com/apache/pulsar-client-go/pkg/pb"
)
func TestUnackedMessageTracker(t *testing.T) {