You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@plc4x.apache.org by sr...@apache.org on 2022/08/17 11:23:28 UTC

[plc4x] 01/04: feat(plc4go/spi): added new ConnectWithContext to transport instance

This is an automated email from the ASF dual-hosted git repository.

sruehl pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/plc4x.git

commit de4fab5f1fe2dfcf2b7a432081c47e80c5d190fe
Author: Sebastian Rühl <sr...@apache.org>
AuthorDate: Wed Aug 17 13:18:29 2022 +0200

    feat(plc4go/spi): added new ConnectWithContext to transport instance
---
 plc4go/spi/testutils/DriverTestRunner.go   |  9 +++-
 plc4go/spi/transports/TransportInstance.go | 67 +++++++++++++++++++++---------
 plc4go/spi/transports/pcap/Transport.go    | 11 ++++-
 plc4go/spi/transports/serial/Transport.go  | 11 ++++-
 plc4go/spi/transports/tcp/Transport.go     | 19 +++++++--
 plc4go/spi/transports/test/Transport.go    |  5 +++
 plc4go/spi/transports/udp/Transport.go     |  8 +++-
 plc4go/spi/utils/net.go                    | 25 ++++++++---
 8 files changed, 120 insertions(+), 35 deletions(-)

diff --git a/plc4go/spi/testutils/DriverTestRunner.go b/plc4go/spi/testutils/DriverTestRunner.go
index e257b110f..7604c3adb 100644
--- a/plc4go/spi/testutils/DriverTestRunner.go
+++ b/plc4go/spi/testutils/DriverTestRunner.go
@@ -74,6 +74,13 @@ func WithRootTypeParser(rootTypeParser func(utils.ReadBufferByteBased) (interfac
 	return withRootTypeParser{rootTypeParser: rootTypeParser}
 }
 
+type TestTransportInstance interface {
+	transports.TransportInstance
+	FillReadBuffer(data []uint8) error
+	GetNumDrainableBytes() uint32
+	DrainWriteBuffer(numBytes uint32) ([]uint8, error)
+}
+
 type withRootTypeParser struct {
 	option
 	rootTypeParser func(utils.ReadBufferByteBased) (interface{}, error)
@@ -136,7 +143,7 @@ func (m DriverTestsuite) ExecuteStep(connection plc4go.PlcConnection, testcase *
 	if !ok {
 		return errors.New("couldn't access connections transport instance")
 	}
-	testTransportInstance, ok := mc.GetTransportInstance().(transports.TestTransportInstance)
+	testTransportInstance, ok := mc.GetTransportInstance().(TestTransportInstance)
 	if !ok {
 		return errors.New("transport must be of type TestTransport")
 	}
diff --git a/plc4go/spi/transports/TransportInstance.go b/plc4go/spi/transports/TransportInstance.go
index 061f0330f..973557883 100644
--- a/plc4go/spi/transports/TransportInstance.go
+++ b/plc4go/spi/transports/TransportInstance.go
@@ -21,11 +21,13 @@ package transports
 
 import (
 	"bufio"
+	"context"
 	"github.com/pkg/errors"
 )
 
 type TransportInstance interface {
 	Connect() error
+	ConnectWithContext(ctx context.Context) error
 	Close() error
 
 	IsConnected() bool
@@ -40,27 +42,52 @@ type TransportInstance interface {
 	Write(data []uint8) error
 }
 
-type TestTransportInstance interface {
-	TransportInstance
-	FillReadBuffer(data []uint8) error
-	GetNumDrainableBytes() uint32
-	DrainWriteBuffer(numBytes uint32) ([]uint8, error)
+type DefaultBufferedTransportInstanceRequirements interface {
+	GetReader() *bufio.Reader
+	Connect() error
+}
+
+type DefaultBufferedTransportInstance interface {
+	ConnectWithContext(ctx context.Context) error
+	GetNumBytesAvailableInBuffer() (uint32, error)
+	FillBuffer(until func(pos uint, currentByte byte, reader *bufio.Reader) bool) error
+	PeekReadableBytes(numBytes uint32) ([]uint8, error)
+	Read(numBytes uint32) ([]uint8, error)
+}
+
+func NewDefaultBufferedTransportInstance(defaultBufferedTransportInstanceRequirements DefaultBufferedTransportInstanceRequirements) DefaultBufferedTransportInstance {
+	return &defaultBufferedTransportInstance{defaultBufferedTransportInstanceRequirements}
 }
 
-type DefaultBufferedTransportInstance struct {
-	*bufio.Reader
+type defaultBufferedTransportInstance struct {
+	DefaultBufferedTransportInstanceRequirements
+}
+
+// ConnectWithContext is a compatibility implementation for those transports not implementing this function
+func (m *defaultBufferedTransportInstance) ConnectWithContext(ctx context.Context) error {
+	ch := make(chan error, 1)
+	go func() {
+		ch <- m.Connect()
+		close(ch)
+	}()
+	select {
+	case err := <-ch:
+		return err
+	case <-ctx.Done():
+		return ctx.Err()
+	}
 }
 
-func (m *DefaultBufferedTransportInstance) GetNumBytesAvailableInBuffer() (uint32, error) {
-	if m.Reader == nil {
+func (m *defaultBufferedTransportInstance) GetNumBytesAvailableInBuffer() (uint32, error) {
+	if m.GetReader() == nil {
 		return 0, nil
 	}
-	_, _ = m.Peek(1)
-	return uint32(m.Buffered()), nil
+	_, _ = m.GetReader().Peek(1)
+	return uint32(m.GetReader().Buffered()), nil
 }
 
-func (m *DefaultBufferedTransportInstance) FillBuffer(until func(pos uint, currentByte byte, reader *bufio.Reader) bool) error {
-	if m.Reader == nil {
+func (m *defaultBufferedTransportInstance) FillBuffer(until func(pos uint, currentByte byte, reader *bufio.Reader) bool) error {
+	if m.GetReader() == nil {
 		return nil
 	}
 	nBytes := uint32(1)
@@ -69,27 +96,27 @@ func (m *DefaultBufferedTransportInstance) FillBuffer(until func(pos uint, curre
 		if err != nil {
 			return errors.Wrap(err, "Error while peeking")
 		}
-		if keepGoing := until(uint(nBytes-1), bytes[len(bytes)-1], m.Reader); !keepGoing {
+		if keepGoing := until(uint(nBytes-1), bytes[len(bytes)-1], m.GetReader()); !keepGoing {
 			return nil
 		}
 		nBytes++
 	}
 }
 
-func (m *DefaultBufferedTransportInstance) PeekReadableBytes(numBytes uint32) ([]uint8, error) {
-	if m.Reader == nil {
+func (m *defaultBufferedTransportInstance) PeekReadableBytes(numBytes uint32) ([]uint8, error) {
+	if m.GetReader() == nil {
 		return nil, errors.New("error peeking from transport. No reader available")
 	}
-	return m.Peek(int(numBytes))
+	return m.GetReader().Peek(int(numBytes))
 }
 
-func (m *DefaultBufferedTransportInstance) Read(numBytes uint32) ([]uint8, error) {
-	if m.Reader == nil {
+func (m *defaultBufferedTransportInstance) Read(numBytes uint32) ([]uint8, error) {
+	if m.GetReader() == nil {
 		return nil, errors.New("error reading from transport. No reader available")
 	}
 	data := make([]uint8, numBytes)
 	for i := uint32(0); i < numBytes; i++ {
-		val, err := m.ReadByte()
+		val, err := m.GetReader().ReadByte()
 		if err != nil {
 			return nil, errors.Wrap(err, "error reading")
 		}
diff --git a/plc4go/spi/transports/pcap/Transport.go b/plc4go/spi/transports/pcap/Transport.go
index 2c5181b84..bdeea2d7d 100644
--- a/plc4go/spi/transports/pcap/Transport.go
+++ b/plc4go/spi/transports/pcap/Transport.go
@@ -88,16 +88,19 @@ type TransportInstance struct {
 	transport     *Transport
 	handle        *pcap.Handle
 	mutex         sync.Mutex
+	reader        *bufio.Reader
 }
 
 func NewPcapTransportInstance(transportFile string, transportType TransportType, portRange string, speedFactor float32, transport *Transport) *TransportInstance {
-	return &TransportInstance{
+	transportInstance := &TransportInstance{
 		transportFile: transportFile,
 		transportType: transportType,
 		portRange:     portRange,
 		speedFactor:   speedFactor,
 		transport:     transport,
 	}
+	transportInstance.DefaultBufferedTransportInstance = transports.NewDefaultBufferedTransportInstance(transportInstance)
+	return transportInstance
 }
 
 func (m *TransportInstance) Connect() error {
@@ -118,7 +121,7 @@ func (m *TransportInstance) Connect() error {
 	m.handle = handle
 	m.connected = true
 	buffer := new(bytes.Buffer)
-	m.Reader = bufio.NewReader(buffer)
+	m.reader = bufio.NewReader(buffer)
 
 	go func(m *TransportInstance, buffer *bytes.Buffer) {
 		packageCount := 0
@@ -186,3 +189,7 @@ func (m *TransportInstance) IsConnected() bool {
 func (m *TransportInstance) Write(_ []uint8) error {
 	panic("Write to pcap not supported")
 }
+
+func (m *TransportInstance) GetReader() *bufio.Reader {
+	return m.reader
+}
diff --git a/plc4go/spi/transports/serial/Transport.go b/plc4go/spi/transports/serial/Transport.go
index f7c046e57..1ccd68939 100644
--- a/plc4go/spi/transports/serial/Transport.go
+++ b/plc4go/spi/transports/serial/Transport.go
@@ -82,15 +82,18 @@ type TransportInstance struct {
 	ConnectTimeout uint32
 	transport      *Transport
 	serialPort     io.ReadWriteCloser
+	reader         *bufio.Reader
 }
 
 func NewTransportInstance(serialPortName string, baudRate uint, connectTimeout uint32, transport *Transport) *TransportInstance {
-	return &TransportInstance{
+	transportInstance := &TransportInstance{
 		SerialPortName: serialPortName,
 		BaudRate:       baudRate,
 		ConnectTimeout: connectTimeout,
 		transport:      transport,
 	}
+	transportInstance.DefaultBufferedTransportInstance = transports.NewDefaultBufferedTransportInstance(transportInstance)
+	return transportInstance
 }
 
 func (m *TransportInstance) Connect() error {
@@ -109,7 +112,7 @@ func (m *TransportInstance) Connect() error {
 		m.serialPort = utils.NewTransportLogger(m.serialPort, utils.WithLogger(fileLogger))
 		log.Trace().Msgf("Logging Transport to file %s", logFile.Name())
 	}*/
-	m.Reader = bufio.NewReader(m.serialPort)
+	m.reader = bufio.NewReader(m.serialPort)
 
 	return nil
 }
@@ -143,3 +146,7 @@ func (m *TransportInstance) Write(data []uint8) error {
 	}
 	return nil
 }
+
+func (m *TransportInstance) GetReader() *bufio.Reader {
+	return m.reader
+}
diff --git a/plc4go/spi/transports/tcp/Transport.go b/plc4go/spi/transports/tcp/Transport.go
index c8b0e8773..caf5cf147 100644
--- a/plc4go/spi/transports/tcp/Transport.go
+++ b/plc4go/spi/transports/tcp/Transport.go
@@ -21,6 +21,7 @@ package tcp
 
 import (
 	"bufio"
+	"context"
 	"fmt"
 	"github.com/apache/plc4x/plc4go/spi/transports"
 	"github.com/apache/plc4x/plc4go/spi/utils"
@@ -100,26 +101,34 @@ type TransportInstance struct {
 	ConnectTimeout uint32
 	transport      *Transport
 	tcpConn        net.Conn
+	reader         *bufio.Reader
 }
 
 func NewTcpTransportInstance(remoteAddress *net.TCPAddr, connectTimeout uint32, transport *Transport) *TransportInstance {
-	return &TransportInstance{
+	transportInstance := &TransportInstance{
 		RemoteAddress:  remoteAddress,
 		ConnectTimeout: connectTimeout,
 		transport:      transport,
 	}
+	transportInstance.DefaultBufferedTransportInstance = transports.NewDefaultBufferedTransportInstance(transportInstance)
+	return transportInstance
 }
 
 func (m *TransportInstance) Connect() error {
+	return m.ConnectWithContext(context.Background())
+}
+
+func (m *TransportInstance) ConnectWithContext(ctx context.Context) error {
 	var err error
-	m.tcpConn, err = net.Dial("tcp", m.RemoteAddress.String())
+	var d net.Dialer
+	m.tcpConn, err = d.DialContext(ctx, "tcp", m.RemoteAddress.String())
 	if err != nil {
 		return errors.Wrap(err, "error connecting to remote address")
 	}
 
 	m.LocalAddress = m.tcpConn.LocalAddr().(*net.TCPAddr)
 
-	m.Reader = bufio.NewReader(m.tcpConn)
+	m.reader = bufio.NewReader(m.tcpConn)
 
 	return nil
 }
@@ -153,3 +162,7 @@ func (m *TransportInstance) Write(data []uint8) error {
 	}
 	return nil
 }
+
+func (m *TransportInstance) GetReader() *bufio.Reader {
+	return m.reader
+}
diff --git a/plc4go/spi/transports/test/Transport.go b/plc4go/spi/transports/test/Transport.go
index 65b041752..ab326b819 100644
--- a/plc4go/spi/transports/test/Transport.go
+++ b/plc4go/spi/transports/test/Transport.go
@@ -22,6 +22,7 @@ package test
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"github.com/apache/plc4x/plc4go/spi/transports"
 	"github.com/pkg/errors"
 	"github.com/rs/zerolog/log"
@@ -70,6 +71,10 @@ func (m *TransportInstance) Connect() error {
 	return nil
 }
 
+func (m *TransportInstance) ConnectWithContext(_ context.Context) error {
+	return m.Connect()
+}
+
 func (m *TransportInstance) Close() error {
 	log.Trace().Msg("Close")
 	m.connected = false
diff --git a/plc4go/spi/transports/udp/Transport.go b/plc4go/spi/transports/udp/Transport.go
index 001fa381e..337f9f16a 100644
--- a/plc4go/spi/transports/udp/Transport.go
+++ b/plc4go/spi/transports/udp/Transport.go
@@ -21,6 +21,7 @@ package udp
 
 import (
 	"bufio"
+	"context"
 	"github.com/apache/plc4x/plc4go/spi/transports"
 	"github.com/apache/plc4x/plc4go/spi/utils"
 	"github.com/libp2p/go-reuseport"
@@ -127,10 +128,15 @@ func NewTransportInstance(localAddress *net.UDPAddr, remoteAddress *net.UDPAddr,
 }
 
 func (m *TransportInstance) Connect() error {
+	return m.ConnectWithContext(context.Background())
+}
+
+func (m *TransportInstance) ConnectWithContext(ctx context.Context) error {
 	// If we haven't provided a local address, have the system figure it out by dialing
 	// the remote address and then using that connections local address as local address.
 	if m.LocalAddress == nil {
-		udpTest, err := net.Dial("udp", m.RemoteAddress.String())
+		var d net.Dialer
+		udpTest, err := d.DialContext(ctx, "udp", m.RemoteAddress.String())
 		if err != nil {
 			return errors.Wrap(err, "error connecting to remote address")
 		}
diff --git a/plc4go/spi/utils/net.go b/plc4go/spi/utils/net.go
index f05a9dc4c..476d08bb4 100644
--- a/plc4go/spi/utils/net.go
+++ b/plc4go/spi/utils/net.go
@@ -23,6 +23,7 @@ import (
 	"bytes"
 	"context"
 	"net"
+	"sync"
 	"time"
 
 	"github.com/google/gopacket"
@@ -32,13 +33,14 @@ import (
 	"github.com/rs/zerolog/log"
 )
 
-func GetIPAddresses(ctx context.Context, netInterface net.Interface, useArpBasedScan bool) (chan net.IP, error) {
-	foundIps := make(chan net.IP, 65536)
+func GetIPAddresses(ctx context.Context, netInterface net.Interface, useArpBasedScan bool) (foundIps chan net.IP, err error) {
+	foundIps = make(chan net.IP, 65536)
 	addrs, err := netInterface.Addrs()
 	if err != nil {
 		return nil, errors.Wrap(err, "Error getting addresses")
 	}
 	go func() {
+		wg := &sync.WaitGroup{}
 		for _, address := range addrs {
 			// Check if context has been cancelled before continuing
 			select {
@@ -64,17 +66,20 @@ func GetIPAddresses(ctx context.Context, netInterface net.Interface, useArpBased
 
 			log.Debug().Stringer("IP", ipnet.IP).Stringer("Mask", ipnet.Mask).Msg("Expanding local subnet")
 			if useArpBasedScan {
-				if err := lockupIpsUsingArp(ctx, netInterface, ipnet, foundIps); err != nil {
+				if err := lockupIpsUsingArp(ctx, netInterface, ipnet, foundIps, wg); err != nil {
 					log.Error().Err(err).Msg("failing to resolve using arp scan. Falling back to ip based scan")
 					useArpBasedScan = false
 				}
 			}
 			if !useArpBasedScan {
-				if err := lookupIps(ctx, ipnet, foundIps); err != nil {
+				if err := lookupIps(ctx, ipnet, foundIps, wg); err != nil {
 					log.Error().Err(err).Msg("error looking up ips")
 				}
 			}
 		}
+		wg.Wait()
+		log.Trace().Msg("Closing found ips channel")
+		close(foundIps)
 	}()
 	return foundIps, nil
 }
@@ -82,7 +87,10 @@ func GetIPAddresses(ctx context.Context, netInterface net.Interface, useArpBased
 // As PING operations might be blocked by a firewall, responding to ARP packets is mandatory for IP based
 // systems. So we are using an ARP scan to resolve the ethernet hardware addresses of each possible ip in range
 // Only for devices that respond will we schedule a discovery.
-func lockupIpsUsingArp(ctx context.Context, netInterface net.Interface, ipNet *net.IPNet, foundIps chan net.IP) error {
+func lockupIpsUsingArp(ctx context.Context, netInterface net.Interface, ipNet *net.IPNet, foundIps chan net.IP, wg *sync.WaitGroup) error {
+	// We add one signal for error handling
+	wg.Add(1)
+	go func() { wg.Done() }()
 	log.Debug().Msgf("Scanning for alive IP addresses for interface '%s' and net: %s", netInterface.Name, ipNet)
 	// First find the pcap device name for the given interface.
 	allDevs, _ := pcap.FindAllDevs()
@@ -108,6 +116,8 @@ func lockupIpsUsingArp(ctx context.Context, netInterface net.Interface, ipNet *n
 
 	// Start up a goroutine to read in packet data.
 	stop := make(chan struct{})
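+	// As we don't know how many addresses the handler will find, we add a single count here and mark it done in the cleanup function directly below (NOTE(review): that cleanup calls wg.Done() before its 10 sec sleep, not after it - confirm this is intended)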
+	wg.Add(1)
 	// Handler for processing incoming ARP responses.
 	go func(handle *pcap.Handle, iface net.Interface, stop chan struct{}) {
 		src := gopacket.NewPacketSource(handle, layers.LayerTypeEthernet)
@@ -146,6 +156,7 @@ func lockupIpsUsingArp(ctx context.Context, netInterface net.Interface, ipNet *n
 	// Make sure we clean up after 10 seconds.
 	defer func() {
 		go func() {
+			wg.Done()
 			time.Sleep(10 * time.Second)
 			handle.Close()
 			close(stop)
@@ -202,7 +213,7 @@ func lockupIpsUsingArp(ctx context.Context, netInterface net.Interface, ipNet *n
 }
 
 // Simply takes the IP address and the netmask and schedules one discovery task for every possible IP
-func lookupIps(ctx context.Context, ipnet *net.IPNet, foundIps chan net.IP) error {
+func lookupIps(ctx context.Context, ipnet *net.IPNet, foundIps chan net.IP, wg *sync.WaitGroup) error {
 	log.Debug().Msgf("Scanning all IP addresses for network: %s", ipnet)
 	// expand CIDR-block into one target for each IP
 	// Remark: The last IP address a network contains is a special broadcast address. We don't want to check that one.
@@ -214,7 +225,9 @@ func lookupIps(ctx context.Context, ipnet *net.IPNet, foundIps chan net.IP) erro
 		default:
 		}
 
+		wg.Add(1)
 		go func(ip net.IP) {
+			defer func() { wg.Done() }()
 			select {
 			case <-ctx.Done():
 			case foundIps <- ip: