You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@plc4x.apache.org by sr...@apache.org on 2022/08/17 11:23:31 UTC

[plc4x] 04/04: fix(plc4go/cbus): avoid channel leak by adding wait groups

This is an automated email from the ASF dual-hosted git repository.

sruehl pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/plc4x.git

commit ce2db6b314c397ffa092aa4046941a6813d3a798
Author: Sebastian Rühl <sr...@apache.org>
AuthorDate: Wed Aug 17 13:23:19 2022 +0200

    fix(plc4go/cbus): avoid channel leak by adding wait groups
---
 plc4go/internal/cbus/Discoverer.go                 | 207 ++++++++++++---------
 .../tests/drivers/tests/manual_cbus_driver_test.go |  21 +++
 2 files changed, 135 insertions(+), 93 deletions(-)

diff --git a/plc4go/internal/cbus/Discoverer.go b/plc4go/internal/cbus/Discoverer.go
index 731c2b4a9..3ef576304 100644
--- a/plc4go/internal/cbus/Discoverer.go
+++ b/plc4go/internal/cbus/Discoverer.go
@@ -25,6 +25,7 @@ import (
 	"github.com/apache/plc4x/plc4go/spi/transports/tcp"
 	"net"
 	"net/url"
+	"sync"
 	"time"
 
 	apiModel "github.com/apache/plc4x/plc4go/pkg/api/model"
@@ -73,13 +74,16 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
 	}
 
 	transportInstances := make(chan transports.TransportInstance)
+	wg := &sync.WaitGroup{}
 	// Iterate over all network devices of this system.
 	for _, netInterface := range interfaces {
 		addrs, err := netInterface.Addrs()
 		if err != nil {
 			return err
 		}
+		wg.Add(1)
 		go func(netInterface net.Interface) {
+			defer func() { wg.Done() }()
 			// Iterate over all addresses the current interface has configured
 			// For KNX we're only interested in IPv4 addresses, as it doesn't
 			// seem to work with IPv6.
@@ -106,24 +110,34 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
 					log.Warn().Err(err).Msgf("Can't get addresses for %v", netInterface)
 					continue
 				}
+				wg.Add(1)
 				go func() {
+					defer func() { wg.Done() }()
 					for ip := range addresses {
+						log.Trace().Msgf("Handling found ip %v", ip)
+						wg.Add(1)
 						go func(ip net.IP) {
+							defer func() { wg.Done() }()
 							// Create a new "connection" (Actually open a local udp socket and target outgoing packets to that address)
-							connectionUrl, err := url.Parse(fmt.Sprintf("tcp://%s:%d", ip, readWriteModel.CBusConstants_CBUSTCPDEFAULTPORT))
-							if err != nil {
-								log.Error().Err(err).Msgf("Error parsing url for lookup")
-								return
+							var connectionUrl url.URL
+							{
+								connectionUrlParsed, err := url.Parse(fmt.Sprintf("tcp://%s:%d", ip, readWriteModel.CBusConstants_CBUSTCPDEFAULTPORT))
+								if err != nil {
+									log.Error().Err(err).Msgf("Error parsing url for lookup")
+									return
+								}
+								connectionUrl = *connectionUrlParsed
 							}
-							transportInstance, err := tcpTransport.CreateTransportInstance(*connectionUrl, nil)
+
+							transportInstance, err := tcpTransport.CreateTransportInstance(connectionUrl, nil)
 							if err != nil {
 								log.Error().Err(err).Msgf("Error creating transport instance")
 								return
 							}
-							log.Trace().Msgf("trying %s", connectionUrl)
-							err = transportInstance.Connect()
+							log.Trace().Msgf("trying %v", connectionUrl)
+							err = transportInstance.ConnectWithContext(ctx)
 							if err != nil {
-								secondErr := transportInstance.Connect()
+								secondErr := transportInstance.ConnectWithContext(ctx)
 								if secondErr != nil {
 									log.Trace().Err(err).Msgf("Error connecting transport instance")
 									return
@@ -137,99 +151,106 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
 			}
 		}(netInterface)
 	}
-
 	go func() {
-		for transportInstance := range transportInstances {
-			tcpTransportInstance := transportInstance.(*tcp.TransportInstance)
-			// Create a codec for sending and receiving messages.
-			codec := NewMessageCodec(transportInstance)
-			// Explicitly start the worker
-			if err := codec.Connect(); err != nil {
-				log.Debug().Err(err).Msg("Error connecting")
-				continue
-			}
+		wg.Wait()
+		log.Trace().Msg("Closing transport instance channel")
+		close(transportInstances)
+	}()
 
-			// Prepare the discovery packet data
-			cBusOptions := readWriteModel.NewCBusOptions(false, false, false, false, false, false, false, false, true)
-			requestContext := readWriteModel.NewRequestContext(false)
-			calData := readWriteModel.NewCALDataIdentify(readWriteModel.Attribute_Manufacturer, readWriteModel.CALCommandTypeContainer_CALCommandIdentify, nil, requestContext)
-			alpha := readWriteModel.NewAlpha('x')
-			request := readWriteModel.NewRequestDirectCommandAccess(calData, alpha, 0x0, nil, nil, readWriteModel.RequestType_DIRECT_COMMAND, readWriteModel.NewRequestTermination(), cBusOptions)
-			cBusMessageToServer := readWriteModel.NewCBusMessageToServer(request, requestContext, cBusOptions)
-			// Send the search request.
-			err = codec.Send(cBusMessageToServer)
-			go func() {
-				// Keep on reading responses till the timeout is done.
-				// TODO: Make this configurable
-				timeout := time.NewTimer(time.Second * 1)
-				timeout.Stop()
-				for start := time.Now(); time.Since(start) < time.Second*5; {
-					timeout.Reset(time.Second * 1)
-					select {
-					case receivedMessage := <-codec.GetDefaultIncomingMessageChannel():
-						if !timeout.Stop() {
-							<-timeout.C
-						}
-						cbusMessage, ok := receivedMessage.(readWriteModel.CBusMessage)
-						if !ok {
-							continue
-						}
-						messageToClient, ok := cbusMessage.(readWriteModel.CBusMessageToClient)
-						if !ok {
-							continue
-						}
-						replyOrConfirmationConfirmation, ok := messageToClient.GetReply().(readWriteModel.ReplyOrConfirmationConfirmationExactly)
-						if !ok {
-							continue
-						}
-						if receivedAlpha := replyOrConfirmationConfirmation.GetConfirmation().GetAlpha(); receivedAlpha != nil && alpha.GetCharacter() != receivedAlpha.GetCharacter() {
-							continue
-						}
-						embeddedReply, ok := replyOrConfirmationConfirmation.GetEmbeddedReply().(readWriteModel.ReplyOrConfirmationReplyExactly)
-						if !ok {
-							continue
-						}
-						encodedReply, ok := embeddedReply.GetReply().(readWriteModel.ReplyEncodedReplyExactly)
-						if !ok {
-							continue
-						}
-						encodedReplyCALReply, ok := encodedReply.GetEncodedReply().(readWriteModel.EncodedReplyCALReplyExactly)
-						if !ok {
-							continue
-						}
-						calDataIdentifyReply, ok := encodedReplyCALReply.GetCalReply().GetCalData().(readWriteModel.CALDataIdentifyReplyExactly)
-						if !ok {
-							continue
-						}
-						identifyReplyCommand, ok := calDataIdentifyReply.GetIdentifyReplyCommand().(readWriteModel.IdentifyReplyCommandManufacturerExactly)
-						if !ok {
-							continue
-						}
+	for transportInstance := range transportInstances {
+		tcpTransportInstance := transportInstance.(*tcp.TransportInstance)
+		// Create a codec for sending and receiving messages.
+		codec := NewMessageCodec(transportInstance)
+		// Explicitly start the worker
+		if err := codec.Connect(); err != nil {
+			log.Debug().Err(err).Msg("Error connecting")
+			continue
+		}
+
+		// Prepare the discovery packet data
+		cBusOptions := readWriteModel.NewCBusOptions(false, false, false, false, false, false, false, false, true)
+		requestContext := readWriteModel.NewRequestContext(false)
+		calData := readWriteModel.NewCALDataIdentify(readWriteModel.Attribute_Manufacturer, readWriteModel.CALCommandTypeContainer_CALCommandIdentify, nil, requestContext)
+		alpha := readWriteModel.NewAlpha('x')
+		request := readWriteModel.NewRequestDirectCommandAccess(calData, alpha, 0x0, nil, nil, readWriteModel.RequestType_DIRECT_COMMAND, readWriteModel.NewRequestTermination(), cBusOptions)
+		cBusMessageToServer := readWriteModel.NewCBusMessageToServer(request, requestContext, cBusOptions)
+		// Send the search request.
+		err = codec.Send(cBusMessageToServer)
+		go func() {
+			// Keep on reading responses till the timeout is done.
+			// TODO: Make this configurable
+			timeout := time.NewTimer(time.Second * 1)
+			timeout.Stop()
+			for start := time.Now(); time.Since(start) < time.Second*5; {
+				timeout.Reset(time.Second * 1)
+				select {
+				case receivedMessage := <-codec.GetDefaultIncomingMessageChannel():
+					if !timeout.Stop() {
+						<-timeout.C
+					}
+					cbusMessage, ok := receivedMessage.(readWriteModel.CBusMessage)
+					if !ok {
+						continue
+					}
+					messageToClient, ok := cbusMessage.(readWriteModel.CBusMessageToClient)
+					if !ok {
+						continue
+					}
+					replyOrConfirmationConfirmation, ok := messageToClient.GetReply().(readWriteModel.ReplyOrConfirmationConfirmationExactly)
+					if !ok {
+						continue
+					}
+					if receivedAlpha := replyOrConfirmationConfirmation.GetConfirmation().GetAlpha(); receivedAlpha != nil && alpha.GetCharacter() != receivedAlpha.GetCharacter() {
+						continue
+					}
+					embeddedReply, ok := replyOrConfirmationConfirmation.GetEmbeddedReply().(readWriteModel.ReplyOrConfirmationReplyExactly)
+					if !ok {
+						continue
+					}
+					encodedReply, ok := embeddedReply.GetReply().(readWriteModel.ReplyEncodedReplyExactly)
+					if !ok {
+						continue
+					}
+					encodedReplyCALReply, ok := encodedReply.GetEncodedReply().(readWriteModel.EncodedReplyCALReplyExactly)
+					if !ok {
+						continue
+					}
+					calDataIdentifyReply, ok := encodedReplyCALReply.GetCalReply().GetCalData().(readWriteModel.CALDataIdentifyReplyExactly)
+					if !ok {
+						continue
+					}
+					identifyReplyCommand, ok := calDataIdentifyReply.GetIdentifyReplyCommand().(readWriteModel.IdentifyReplyCommandManufacturerExactly)
+					if !ok {
+						continue
+					}
+					var remoteUrl url.URL
+					{
 						// TODO: we could check for the exact reponse
-						remoteUrl, err := url.Parse(fmt.Sprintf("tcp://%s", tcpTransportInstance.RemoteAddress))
+						remoteUrlParse, err := url.Parse(fmt.Sprintf("tcp://%s", tcpTransportInstance.RemoteAddress))
 						if err != nil {
 							log.Error().Err(err).Msg("Error creating url")
 							continue
 						}
-						// TODO: manufaturer + type would be good but this means two requests then
-						deviceName := identifyReplyCommand.GetManufacturerName()
-						discoveryEvent := &internalModel.DefaultPlcDiscoveryEvent{
-							ProtocolCode:  "c-bus",
-							TransportCode: "tcp",
-							TransportUrl:  *remoteUrl,
-							Options:       nil,
-							Name:          deviceName,
-						}
-						// Pass the event back to the callback
-						callback(discoveryEvent)
-						continue
-					case <-timeout.C:
-						timeout.Stop()
-						continue
+						remoteUrl = *remoteUrlParse
 					}
+					// TODO: manufaturer + type would be good but this means two requests then
+					deviceName := identifyReplyCommand.GetManufacturerName()
+					discoveryEvent := &internalModel.DefaultPlcDiscoveryEvent{
+						ProtocolCode:  "c-bus",
+						TransportCode: "tcp",
+						TransportUrl:  remoteUrl,
+						Options:       nil,
+						Name:          deviceName,
+					}
+					// Pass the event back to the callback
+					callback(discoveryEvent)
+					continue
+				case <-timeout.C:
+					timeout.Stop()
+					continue
 				}
-			}()
-		}
-	}()
+			}
+		}()
+	}
 	return nil
 }
diff --git a/plc4go/tests/drivers/tests/manual_cbus_driver_test.go b/plc4go/tests/drivers/tests/manual_cbus_driver_test.go
index 98fcadc37..798111bf7 100644
--- a/plc4go/tests/drivers/tests/manual_cbus_driver_test.go
+++ b/plc4go/tests/drivers/tests/manual_cbus_driver_test.go
@@ -158,3 +158,24 @@ func TestManualCBusRead(t *testing.T) {
 	readRequestResult := <-readRequest.Execute()
 	fmt.Printf("%s", readRequestResult.GetResponse())
 }
+
+func TestManualDiscovery(t *testing.T) {
+	// Manual test; skip before mutating package-global logger/config so skipped runs leave no side effects.
+	t.Skip("manual test: requires a reachable C-Bus interface on the local network")
+	log.Logger = log.
+		With().Caller().Logger().
+		Output(zerolog.ConsoleWriter{Out: os.Stderr}).
+		Level(zerolog.TraceLevel)
+	config.TraceTransactionManagerWorkers = false
+	config.TraceTransactionManagerTransactions = false
+	config.TraceDefaultMessageCodecWorker = false
+
+	driverManager := plc4go.NewPlcDriverManager()
+	driver := cbus.NewDriver()
+	driverManager.RegisterDriver(driver)
+	transports.RegisterTcpTransport(driverManager)
+	err := driver.Discover(func(event model.PlcDiscoveryEvent) {
+		fmt.Println(event) // %v uses String() when implemented and never panics otherwise
+	})
+	require.NoError(t, err)
+}