You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the word "here", which is not preserved in this plain-text rendering.
Posted to commits@plc4x.apache.org by sr...@apache.org on 2022/08/17 11:23:31 UTC
[plc4x] 04/04: fix(plc4go/cbus): avoid channel leak by adding wait groups
This is an automated email from the ASF dual-hosted git repository.
sruehl pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/plc4x.git
commit ce2db6b314c397ffa092aa4046941a6813d3a798
Author: Sebastian Rühl <sr...@apache.org>
AuthorDate: Wed Aug 17 13:23:19 2022 +0200
fix(plc4go/cbus): avoid channel leak by adding wait groups
---
plc4go/internal/cbus/Discoverer.go | 207 ++++++++++++---------
.../tests/drivers/tests/manual_cbus_driver_test.go | 21 +++
2 files changed, 135 insertions(+), 93 deletions(-)
diff --git a/plc4go/internal/cbus/Discoverer.go b/plc4go/internal/cbus/Discoverer.go
index 731c2b4a9..3ef576304 100644
--- a/plc4go/internal/cbus/Discoverer.go
+++ b/plc4go/internal/cbus/Discoverer.go
@@ -25,6 +25,7 @@ import (
"github.com/apache/plc4x/plc4go/spi/transports/tcp"
"net"
"net/url"
+ "sync"
"time"
apiModel "github.com/apache/plc4x/plc4go/pkg/api/model"
@@ -73,13 +74,16 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
}
transportInstances := make(chan transports.TransportInstance)
+ wg := &sync.WaitGroup{}
// Iterate over all network devices of this system.
for _, netInterface := range interfaces {
addrs, err := netInterface.Addrs()
if err != nil {
return err
}
+ wg.Add(1)
go func(netInterface net.Interface) {
+ defer func() { wg.Done() }()
// Iterate over all addresses the current interface has configured
// For KNX we're only interested in IPv4 addresses, as it doesn't
// seem to work with IPv6.
@@ -106,24 +110,34 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
log.Warn().Err(err).Msgf("Can't get addresses for %v", netInterface)
continue
}
+ wg.Add(1)
go func() {
+ defer func() { wg.Done() }()
for ip := range addresses {
+ log.Trace().Msgf("Handling found ip %v", ip)
+ wg.Add(1)
go func(ip net.IP) {
+ defer func() { wg.Done() }()
// Create a new "connection" (Actually open a local udp socket and target outgoing packets to that address)
- connectionUrl, err := url.Parse(fmt.Sprintf("tcp://%s:%d", ip, readWriteModel.CBusConstants_CBUSTCPDEFAULTPORT))
- if err != nil {
- log.Error().Err(err).Msgf("Error parsing url for lookup")
- return
+ var connectionUrl url.URL
+ {
+ connectionUrlParsed, err := url.Parse(fmt.Sprintf("tcp://%s:%d", ip, readWriteModel.CBusConstants_CBUSTCPDEFAULTPORT))
+ if err != nil {
+ log.Error().Err(err).Msgf("Error parsing url for lookup")
+ return
+ }
+ connectionUrl = *connectionUrlParsed
}
- transportInstance, err := tcpTransport.CreateTransportInstance(*connectionUrl, nil)
+
+ transportInstance, err := tcpTransport.CreateTransportInstance(connectionUrl, nil)
if err != nil {
log.Error().Err(err).Msgf("Error creating transport instance")
return
}
- log.Trace().Msgf("trying %s", connectionUrl)
- err = transportInstance.Connect()
+ log.Trace().Msgf("trying %v", connectionUrl)
+ err = transportInstance.ConnectWithContext(ctx)
if err != nil {
- secondErr := transportInstance.Connect()
+ secondErr := transportInstance.ConnectWithContext(ctx)
if secondErr != nil {
log.Trace().Err(err).Msgf("Error connecting transport instance")
return
@@ -137,99 +151,106 @@ func (d *Discoverer) Discover(ctx context.Context, callback func(event apiModel.
}
}(netInterface)
}
-
go func() {
- for transportInstance := range transportInstances {
- tcpTransportInstance := transportInstance.(*tcp.TransportInstance)
- // Create a codec for sending and receiving messages.
- codec := NewMessageCodec(transportInstance)
- // Explicitly start the worker
- if err := codec.Connect(); err != nil {
- log.Debug().Err(err).Msg("Error connecting")
- continue
- }
+ wg.Wait()
+ log.Trace().Msg("Closing transport instance channel")
+ close(transportInstances)
+ }()
- // Prepare the discovery packet data
- cBusOptions := readWriteModel.NewCBusOptions(false, false, false, false, false, false, false, false, true)
- requestContext := readWriteModel.NewRequestContext(false)
- calData := readWriteModel.NewCALDataIdentify(readWriteModel.Attribute_Manufacturer, readWriteModel.CALCommandTypeContainer_CALCommandIdentify, nil, requestContext)
- alpha := readWriteModel.NewAlpha('x')
- request := readWriteModel.NewRequestDirectCommandAccess(calData, alpha, 0x0, nil, nil, readWriteModel.RequestType_DIRECT_COMMAND, readWriteModel.NewRequestTermination(), cBusOptions)
- cBusMessageToServer := readWriteModel.NewCBusMessageToServer(request, requestContext, cBusOptions)
- // Send the search request.
- err = codec.Send(cBusMessageToServer)
- go func() {
- // Keep on reading responses till the timeout is done.
- // TODO: Make this configurable
- timeout := time.NewTimer(time.Second * 1)
- timeout.Stop()
- for start := time.Now(); time.Since(start) < time.Second*5; {
- timeout.Reset(time.Second * 1)
- select {
- case receivedMessage := <-codec.GetDefaultIncomingMessageChannel():
- if !timeout.Stop() {
- <-timeout.C
- }
- cbusMessage, ok := receivedMessage.(readWriteModel.CBusMessage)
- if !ok {
- continue
- }
- messageToClient, ok := cbusMessage.(readWriteModel.CBusMessageToClient)
- if !ok {
- continue
- }
- replyOrConfirmationConfirmation, ok := messageToClient.GetReply().(readWriteModel.ReplyOrConfirmationConfirmationExactly)
- if !ok {
- continue
- }
- if receivedAlpha := replyOrConfirmationConfirmation.GetConfirmation().GetAlpha(); receivedAlpha != nil && alpha.GetCharacter() != receivedAlpha.GetCharacter() {
- continue
- }
- embeddedReply, ok := replyOrConfirmationConfirmation.GetEmbeddedReply().(readWriteModel.ReplyOrConfirmationReplyExactly)
- if !ok {
- continue
- }
- encodedReply, ok := embeddedReply.GetReply().(readWriteModel.ReplyEncodedReplyExactly)
- if !ok {
- continue
- }
- encodedReplyCALReply, ok := encodedReply.GetEncodedReply().(readWriteModel.EncodedReplyCALReplyExactly)
- if !ok {
- continue
- }
- calDataIdentifyReply, ok := encodedReplyCALReply.GetCalReply().GetCalData().(readWriteModel.CALDataIdentifyReplyExactly)
- if !ok {
- continue
- }
- identifyReplyCommand, ok := calDataIdentifyReply.GetIdentifyReplyCommand().(readWriteModel.IdentifyReplyCommandManufacturerExactly)
- if !ok {
- continue
- }
+ for transportInstance := range transportInstances {
+ tcpTransportInstance := transportInstance.(*tcp.TransportInstance)
+ // Create a codec for sending and receiving messages.
+ codec := NewMessageCodec(transportInstance)
+ // Explicitly start the worker
+ if err := codec.Connect(); err != nil {
+ log.Debug().Err(err).Msg("Error connecting")
+ continue
+ }
+
+ // Prepare the discovery packet data
+ cBusOptions := readWriteModel.NewCBusOptions(false, false, false, false, false, false, false, false, true)
+ requestContext := readWriteModel.NewRequestContext(false)
+ calData := readWriteModel.NewCALDataIdentify(readWriteModel.Attribute_Manufacturer, readWriteModel.CALCommandTypeContainer_CALCommandIdentify, nil, requestContext)
+ alpha := readWriteModel.NewAlpha('x')
+ request := readWriteModel.NewRequestDirectCommandAccess(calData, alpha, 0x0, nil, nil, readWriteModel.RequestType_DIRECT_COMMAND, readWriteModel.NewRequestTermination(), cBusOptions)
+ cBusMessageToServer := readWriteModel.NewCBusMessageToServer(request, requestContext, cBusOptions)
+ // Send the search request.
+ err = codec.Send(cBusMessageToServer)
+ go func() {
+ // Keep on reading responses till the timeout is done.
+ // TODO: Make this configurable
+ timeout := time.NewTimer(time.Second * 1)
+ timeout.Stop()
+ for start := time.Now(); time.Since(start) < time.Second*5; {
+ timeout.Reset(time.Second * 1)
+ select {
+ case receivedMessage := <-codec.GetDefaultIncomingMessageChannel():
+ if !timeout.Stop() {
+ <-timeout.C
+ }
+ cbusMessage, ok := receivedMessage.(readWriteModel.CBusMessage)
+ if !ok {
+ continue
+ }
+ messageToClient, ok := cbusMessage.(readWriteModel.CBusMessageToClient)
+ if !ok {
+ continue
+ }
+ replyOrConfirmationConfirmation, ok := messageToClient.GetReply().(readWriteModel.ReplyOrConfirmationConfirmationExactly)
+ if !ok {
+ continue
+ }
+ if receivedAlpha := replyOrConfirmationConfirmation.GetConfirmation().GetAlpha(); receivedAlpha != nil && alpha.GetCharacter() != receivedAlpha.GetCharacter() {
+ continue
+ }
+ embeddedReply, ok := replyOrConfirmationConfirmation.GetEmbeddedReply().(readWriteModel.ReplyOrConfirmationReplyExactly)
+ if !ok {
+ continue
+ }
+ encodedReply, ok := embeddedReply.GetReply().(readWriteModel.ReplyEncodedReplyExactly)
+ if !ok {
+ continue
+ }
+ encodedReplyCALReply, ok := encodedReply.GetEncodedReply().(readWriteModel.EncodedReplyCALReplyExactly)
+ if !ok {
+ continue
+ }
+ calDataIdentifyReply, ok := encodedReplyCALReply.GetCalReply().GetCalData().(readWriteModel.CALDataIdentifyReplyExactly)
+ if !ok {
+ continue
+ }
+ identifyReplyCommand, ok := calDataIdentifyReply.GetIdentifyReplyCommand().(readWriteModel.IdentifyReplyCommandManufacturerExactly)
+ if !ok {
+ continue
+ }
+ var remoteUrl url.URL
+ {
// TODO: we could check for the exact reponse
- remoteUrl, err := url.Parse(fmt.Sprintf("tcp://%s", tcpTransportInstance.RemoteAddress))
+ remoteUrlParse, err := url.Parse(fmt.Sprintf("tcp://%s", tcpTransportInstance.RemoteAddress))
if err != nil {
log.Error().Err(err).Msg("Error creating url")
continue
}
- // TODO: manufaturer + type would be good but this means two requests then
- deviceName := identifyReplyCommand.GetManufacturerName()
- discoveryEvent := &internalModel.DefaultPlcDiscoveryEvent{
- ProtocolCode: "c-bus",
- TransportCode: "tcp",
- TransportUrl: *remoteUrl,
- Options: nil,
- Name: deviceName,
- }
- // Pass the event back to the callback
- callback(discoveryEvent)
- continue
- case <-timeout.C:
- timeout.Stop()
- continue
+ remoteUrl = *remoteUrlParse
}
+ // TODO: manufaturer + type would be good but this means two requests then
+ deviceName := identifyReplyCommand.GetManufacturerName()
+ discoveryEvent := &internalModel.DefaultPlcDiscoveryEvent{
+ ProtocolCode: "c-bus",
+ TransportCode: "tcp",
+ TransportUrl: remoteUrl,
+ Options: nil,
+ Name: deviceName,
+ }
+ // Pass the event back to the callback
+ callback(discoveryEvent)
+ continue
+ case <-timeout.C:
+ timeout.Stop()
+ continue
}
- }()
- }
- }()
+ }
+ }()
+ }
return nil
}
diff --git a/plc4go/tests/drivers/tests/manual_cbus_driver_test.go b/plc4go/tests/drivers/tests/manual_cbus_driver_test.go
index 98fcadc37..798111bf7 100644
--- a/plc4go/tests/drivers/tests/manual_cbus_driver_test.go
+++ b/plc4go/tests/drivers/tests/manual_cbus_driver_test.go
@@ -158,3 +158,24 @@ func TestManualCBusRead(t *testing.T) {
readRequestResult := <-readRequest.Execute()
fmt.Printf("%s", readRequestResult.GetResponse())
}
+
+func TestManualDiscovery(t *testing.T) { // Manual smoke test: runs C-Bus driver discovery and prints every discovery event it reports.
+ log.Logger = log.
+ With().Caller().Logger().
+ Output(zerolog.ConsoleWriter{Out: os.Stderr}).
+ Level(zerolog.TraceLevel) // replaces the package-global zerolog logger: trace level, human-readable output on stderr
+ config.TraceTransactionManagerWorkers = false // NOTE(review): these global switches presumably silence internal worker/transaction tracing — confirm in spi config
+ config.TraceTransactionManagerTransactions = false
+ config.TraceDefaultMessageCodecWorker = false
+ t.Skip() // manual test, skipped by default; NOTE(review): runs after the global logger/config changes above, so they still apply when skipped
+
+ driverManager := plc4go.NewPlcDriverManager()
+ driver := cbus.NewDriver()
+ driverManager.RegisterDriver(driver)
+ transports.RegisterTcpTransport(driverManager) // discovery in this driver creates TCP transport instances
+ err := driver.Discover(func(event model.PlcDiscoveryEvent) { // callback is invoked once per discovered device
+ println(event.(fmt.Stringer).String()) // unchecked assertion: panics if an event does not implement fmt.Stringer
+ })
+ require.NoError(t, err) // only asserts that Discover itself reported no error; discovered devices are not checked
+
+}