You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@plc4x.apache.org by sr...@apache.org on 2023/06/15 14:54:05 UTC

[plc4x] 02/02: fix(plc4go/cbus): fix some concurrency issue when closing the codec

This is an automated email from the ASF dual-hosted git repository.

sruehl pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/plc4x.git

commit 1797abebc554422431e808f91a0a55efc0e8896f
Author: Sebastian Rühl <sr...@apache.org>
AuthorDate: Thu Jun 15 16:53:56 2023 +0200

    fix(plc4go/cbus): fix some concurrency issue when closing the codec
---
 plc4go/internal/cbus/Browser_test.go      |  2 +-
 plc4go/internal/cbus/Connection.go        | 20 +++++++++++++++++-
 plc4go/internal/cbus/Discoverer.go        |  2 +-
 plc4go/internal/cbus/MessageCodec.go      | 34 +++++++++++++++++++++++++++++--
 plc4go/internal/cbus/MessageCodec_test.go | 17 ++++++++++++----
 5 files changed, 66 insertions(+), 9 deletions(-)

diff --git a/plc4go/internal/cbus/Browser_test.go b/plc4go/internal/cbus/Browser_test.go
index 5f4c0c7529..ba70fc1bea 100644
--- a/plc4go/internal/cbus/Browser_test.go
+++ b/plc4go/internal/cbus/Browser_test.go
@@ -438,7 +438,7 @@ func TestBrowser_getInstalledUnitAddressBytes(t *testing.T) {
 					select {
 					case <-fields.connection.Close():
 					case <-timer.C:
-						t.Error("timeout")
+						t.Error("timeout waiting for connection close")
 					}
 				})
 			},
diff --git a/plc4go/internal/cbus/Connection.go b/plc4go/internal/cbus/Connection.go
index cb2655a389..7a1652e17c 100644
--- a/plc4go/internal/cbus/Connection.go
+++ b/plc4go/internal/cbus/Connection.go
@@ -69,6 +69,8 @@ type Connection struct {
 	configuration Configuration `stringer:"true"`
 	driverContext DriverContext `stringer:"true"`
 
+	handlerWaitGroup sync.WaitGroup
+
 	connectionId string
 	tracer       tracer.Tracer
 
@@ -129,7 +131,7 @@ func (c *Connection) ConnectWithContext(ctx context.Context) <-chan plc4go.PlcCo
 				c.fireConnectionError(errors.Errorf("panic-ed %v. Stack:\n%s", err, debug.Stack()), ch)
 			}
 		}()
-		if err := c.messageCodec.Connect(); err != nil {
+		if err := c.messageCodec.ConnectWithContext(ctx); err != nil {
 			c.fireConnectionError(errors.Wrap(err, "Error connecting codec"), ch)
 			return
 		}
@@ -150,6 +152,26 @@ func (c *Connection) ConnectWithContext(ctx context.Context) <-chan plc4go.PlcCo
 	return ch
 }

+// Close closes the connection via DefaultConnection.Close and, before
+// publishing that result, waits until the SAL and MMI handler goroutines
+// started by startSubscriptionHandler have returned. The returned channel is
+// buffered and receives exactly one result.
+func (c *Connection) Close() <-chan plc4go.PlcConnectionCloseResult {
+	results := make(chan plc4go.PlcConnectionCloseResult, 1)
+	go func() {
+		result := <-c.DefaultConnection.Close()
+		// NOTE(review): Wait returns only once the handler goroutines have
+		// called Done; this presumably relies on the codec closing its monitor
+		// channels during the close above so the handlers can exit — if not,
+		// no result is ever delivered. Confirm.
+		c.log.Trace().Msg("Waiting for handlers to stop")
+		c.handlerWaitGroup.Wait()
+		c.log.Trace().Msg("handlers stopped, dispatching result")
+		results <- result
+	}()
+	return results
+}
+
 func (c *Connection) GetMetadata() apiModel.PlcConnectionMetadata {
 	return _default.DefaultConnectionMetadata{
 		ProvidesReading:     true,
 func (c *Connection) GetMetadata() apiModel.PlcConnectionMetadata {
 	return _default.DefaultConnectionMetadata{
 		ProvidesReading:     true,
@@ -231,7 +245,9 @@ func (c *Connection) setupConnection(ctx context.Context, ch chan plc4go.PlcConn
 
 func (c *Connection) startSubscriptionHandler() {
 	c.log.Debug().Msg("Starting SAL handler")
+	c.handlerWaitGroup.Add(1)
 	go func() {
+		defer c.handlerWaitGroup.Done()
 		defer func() {
 			if err := recover(); err != nil {
 				c.log.Error().Msgf("panic-ed %v. Stack:\n%s", err, debug.Stack())
@@ -255,7 +271,9 @@ func (c *Connection) startSubscriptionHandler() {
 		c.log.Info().Msg("Ending SAL handler")
 	}()
 	c.log.Debug().Msg("Starting MMI handler")
+	c.handlerWaitGroup.Add(1)
 	go func() {
+		defer c.handlerWaitGroup.Done()
 		defer func() {
 			if err := recover(); err != nil {
 				c.log.Error().Msgf("panic-ed %v. Stack:\n%s", err, debug.Stack())
diff --git a/plc4go/internal/cbus/Discoverer.go b/plc4go/internal/cbus/Discoverer.go
index 6e06869113..a3648227a9 100644
--- a/plc4go/internal/cbus/Discoverer.go
+++ b/plc4go/internal/cbus/Discoverer.go
@@ -210,7 +210,7 @@ func (d *Discoverer) createDeviceScanDispatcher(tcpTransportInstance *tcp.Transp
 		// Create a codec for sending and receiving messages.
 		codec := NewMessageCodec(tcpTransportInstance, options.WithCustomLogger(d.log))
 		// Explicitly start the worker
-		if err := codec.Connect(); err != nil {
+		if err := codec.ConnectWithContext(context.TODO()); err != nil {
 			transportInstanceLogger.Debug().Err(err).Msg("Error connecting")
 			return
 		}
diff --git a/plc4go/internal/cbus/MessageCodec.go b/plc4go/internal/cbus/MessageCodec.go
index ce5504e39d..831c13ef86 100644
--- a/plc4go/internal/cbus/MessageCodec.go
+++ b/plc4go/internal/cbus/MessageCodec.go
@@ -22,6 +22,7 @@ package cbus
 import (
 	"bufio"
 	"context"
+	"sync"
 	"sync/atomic"
 
 	readWriteModel "github.com/apache/plc4x/plc4go/protocols/cbus/readwrite/model"
@@ -49,6 +50,8 @@ type MessageCodec struct {
 
 	currentlyReportedServerErrors atomic.Uint64
 
+	stateChange sync.Mutex
+
 	passLogToModel bool           `ignore:"true"`
 	log            zerolog.Logger `ignore:"true"`
 }
@@ -57,8 +60,6 @@ func NewMessageCodec(transportInstance transports.TransportInstance, _options ..
 	codec := &MessageCodec{
 		requestContext: readWriteModel.NewRequestContext(false),
 		cbusOptions:    readWriteModel.NewCBusOptions(false, false, false, false, false, false, false, false, false),
-		monitoredMMIs:  make(chan readWriteModel.CALReply, 100),
-		monitoredSALs:  make(chan readWriteModel.MonitoredSAL, 100),
 		passLogToModel: options.ExtractPassLoggerToModel(_options...),
 		log:            options.ExtractCustomLogger(_options...),
 	}
@@ -70,6 +71,52 @@ func (m *MessageCodec) GetCodec() spi.MessageCodec {
 	return m
 }

+// Connect is equivalent to ConnectWithContext(context.Background()). It is
+// overridden so that Connect also performs this type's channel setup;
+// presumably the embedded DefaultCodec's Connect would not route through
+// MessageCodec.ConnectWithContext.
+func (m *MessageCodec) Connect() error {
+	return m.ConnectWithContext(context.Background())
+}
+
+// ConnectWithContext (re)creates the monitoredMMIs and monitoredSALs channels
+// and then connects the embedded DefaultCodec with ctx, returning its error.
+// The channels are recreated on every connect because Disconnect closes them.
+// It fails with "already running" if the codec is running; stateChange
+// serializes it against Disconnect.
+func (m *MessageCodec) ConnectWithContext(ctx context.Context) error {
+	m.stateChange.Lock()
+	defer m.stateChange.Unlock()
+	if m.IsRunning() {
+		return errors.New("already running")
+	}
+	m.log.Trace().Msg("building channels")
+	m.monitoredMMIs = make(chan readWriteModel.CALReply, 100)
+	m.monitoredSALs = make(chan readWriteModel.MonitoredSAL, 100)
+	return m.DefaultCodec.ConnectWithContext(ctx)
+}
+
+// Disconnect disconnects the embedded DefaultCodec and then closes the
+// monitoredMMIs and monitoredSALs channels so that their receivers observe the
+// end of the stream. It fails with "already disconnected" if the codec is not
+// running; otherwise it returns the DefaultCodec's error. The channels are
+// closed even if that error is non-nil.
+func (m *MessageCodec) Disconnect() error {
+	m.stateChange.Lock()
+	defer m.stateChange.Unlock()
+	if !m.IsRunning() {
+		return errors.New("already disconnected")
+	}
+	err := m.DefaultCodec.Disconnect()
+	// NOTE(review): closing after a failed DefaultCodec.Disconnect is only safe
+	// if the codec is no longer running afterwards; otherwise a later
+	// Disconnect would close these channels a second time and panic — confirm.
+	m.log.Trace().Msg("closing channels")
+	close(m.monitoredMMIs)
+	close(m.monitoredSALs)
+	return err
+}
+
 func (m *MessageCodec) Send(message spi.Message) error {
 	m.log.Trace().Msg("Sending message")
 	// Cast the message to the correct type of struct
diff --git a/plc4go/internal/cbus/MessageCodec_test.go b/plc4go/internal/cbus/MessageCodec_test.go
index 142e26b2cc..91db4635ad 100644
--- a/plc4go/internal/cbus/MessageCodec_test.go
+++ b/plc4go/internal/cbus/MessageCodec_test.go
@@ -749,9 +749,10 @@ func Test_extractMMIAndSAL(t *testing.T) {
 		message spi.Message
 	}
 	tests := []struct {
-		name string
-		args args
-		want bool
+		name  string
+		args  args
+		setup func(t *testing.T, args *args)
+		want  bool
 	}{
 		{
 			name: "extract it",
@@ -759,7 +760,6 @@ func Test_extractMMIAndSAL(t *testing.T) {
 		{
 			name: "monitored sal",
 			args: args{
-				codec: NewMessageCodec(nil),
 				message: readWriteModel.NewCBusMessageToClient(
 					readWriteModel.NewReplyOrConfirmationReply(
 						readWriteModel.NewReplyEncodedReply(
@@ -783,10 +783,19 @@ func Test_extractMMIAndSAL(t *testing.T) {
 					nil,
 				),
 			},
+			setup: func(t *testing.T, args *args) {
+				_options := testutils.EnrichOptionsWithOptionsForTesting(t)
+				codec := NewMessageCodec(nil, _options...)
+				codec.monitoredSALs = make(chan readWriteModel.MonitoredSAL, 1)
+				args.codec = codec
+			},
 		},
 	}
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
+			if tt.setup != nil {
+				tt.setup(t, &tt.args)
+			}
 			assert.Equalf(t, tt.want, extractMMIAndSAL(testutils.ProduceTestingLogger(t))(tt.args.codec, tt.args.message), "extractMMIAndSAL(%v, %v)", tt.args.codec, tt.args.message)
 		})
 	}