You are viewing a plain-text version of this content; the hyperlink to the canonical version is not preserved in this copy.
Posted to commits@plc4x.apache.org by sr...@apache.org on 2022/08/17 11:23:28 UTC
[plc4x] 01/04: feat(plc4go/spi): added new ConnectWithContext to transport instance
This is an automated email from the ASF dual-hosted git repository.
sruehl pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/plc4x.git
commit de4fab5f1fe2dfcf2b7a432081c47e80c5d190fe
Author: Sebastian Rühl <sr...@apache.org>
AuthorDate: Wed Aug 17 13:18:29 2022 +0200
feat(plc4go/spi): added new ConnectWithContext to transport instance
---
plc4go/spi/testutils/DriverTestRunner.go | 9 +++-
plc4go/spi/transports/TransportInstance.go | 67 +++++++++++++++++++++---------
plc4go/spi/transports/pcap/Transport.go | 11 ++++-
plc4go/spi/transports/serial/Transport.go | 11 ++++-
plc4go/spi/transports/tcp/Transport.go | 19 +++++++--
plc4go/spi/transports/test/Transport.go | 5 +++
plc4go/spi/transports/udp/Transport.go | 8 +++-
plc4go/spi/utils/net.go | 25 ++++++++---
8 files changed, 120 insertions(+), 35 deletions(-)
diff --git a/plc4go/spi/testutils/DriverTestRunner.go b/plc4go/spi/testutils/DriverTestRunner.go
index e257b110f..7604c3adb 100644
--- a/plc4go/spi/testutils/DriverTestRunner.go
+++ b/plc4go/spi/testutils/DriverTestRunner.go
@@ -74,6 +74,13 @@ func WithRootTypeParser(rootTypeParser func(utils.ReadBufferByteBased) (interfac
return withRootTypeParser{rootTypeParser: rootTypeParser}
}
+type TestTransportInstance interface {
+ transports.TransportInstance
+ FillReadBuffer(data []uint8) error
+ GetNumDrainableBytes() uint32
+ DrainWriteBuffer(numBytes uint32) ([]uint8, error)
+}
+
type withRootTypeParser struct {
option
rootTypeParser func(utils.ReadBufferByteBased) (interface{}, error)
@@ -136,7 +143,7 @@ func (m DriverTestsuite) ExecuteStep(connection plc4go.PlcConnection, testcase *
if !ok {
return errors.New("couldn't access connections transport instance")
}
- testTransportInstance, ok := mc.GetTransportInstance().(transports.TestTransportInstance)
+ testTransportInstance, ok := mc.GetTransportInstance().(TestTransportInstance)
if !ok {
return errors.New("transport must be of type TestTransport")
}
diff --git a/plc4go/spi/transports/TransportInstance.go b/plc4go/spi/transports/TransportInstance.go
index 061f0330f..973557883 100644
--- a/plc4go/spi/transports/TransportInstance.go
+++ b/plc4go/spi/transports/TransportInstance.go
@@ -21,11 +21,13 @@ package transports
import (
"bufio"
+ "context"
"github.com/pkg/errors"
)
type TransportInstance interface {
Connect() error
+ ConnectWithContext(ctx context.Context) error
Close() error
IsConnected() bool
@@ -40,27 +42,52 @@ type TransportInstance interface {
Write(data []uint8) error
}
-type TestTransportInstance interface {
- TransportInstance
- FillReadBuffer(data []uint8) error
- GetNumDrainableBytes() uint32
- DrainWriteBuffer(numBytes uint32) ([]uint8, error)
+type DefaultBufferedTransportInstanceRequirements interface {
+ GetReader() *bufio.Reader
+ Connect() error
+}
+
+type DefaultBufferedTransportInstance interface {
+ ConnectWithContext(ctx context.Context) error
+ GetNumBytesAvailableInBuffer() (uint32, error)
+ FillBuffer(until func(pos uint, currentByte byte, reader *bufio.Reader) bool) error
+ PeekReadableBytes(numBytes uint32) ([]uint8, error)
+ Read(numBytes uint32) ([]uint8, error)
+}
+
+func NewDefaultBufferedTransportInstance(defaultBufferedTransportInstanceRequirements DefaultBufferedTransportInstanceRequirements) DefaultBufferedTransportInstance {
+ return &defaultBufferedTransportInstance{defaultBufferedTransportInstanceRequirements}
}
-type DefaultBufferedTransportInstance struct {
- *bufio.Reader
+type defaultBufferedTransportInstance struct {
+ DefaultBufferedTransportInstanceRequirements
+}
+
+// ConnectWithContext is a compatibility implementation for those transports not implementing this function
+func (m *defaultBufferedTransportInstance) ConnectWithContext(ctx context.Context) error {
+ ch := make(chan error, 1)
+ go func() {
+ ch <- m.Connect()
+ close(ch)
+ }()
+ select {
+ case err := <-ch:
+ return err
+ case <-ctx.Done():
+ return ctx.Err()
+ }
}
-func (m *DefaultBufferedTransportInstance) GetNumBytesAvailableInBuffer() (uint32, error) {
- if m.Reader == nil {
+func (m *defaultBufferedTransportInstance) GetNumBytesAvailableInBuffer() (uint32, error) {
+ if m.GetReader() == nil {
return 0, nil
}
- _, _ = m.Peek(1)
- return uint32(m.Buffered()), nil
+ _, _ = m.GetReader().Peek(1)
+ return uint32(m.GetReader().Buffered()), nil
}
-func (m *DefaultBufferedTransportInstance) FillBuffer(until func(pos uint, currentByte byte, reader *bufio.Reader) bool) error {
- if m.Reader == nil {
+func (m *defaultBufferedTransportInstance) FillBuffer(until func(pos uint, currentByte byte, reader *bufio.Reader) bool) error {
+ if m.GetReader() == nil {
return nil
}
nBytes := uint32(1)
@@ -69,27 +96,27 @@ func (m *DefaultBufferedTransportInstance) FillBuffer(until func(pos uint, curre
if err != nil {
return errors.Wrap(err, "Error while peeking")
}
- if keepGoing := until(uint(nBytes-1), bytes[len(bytes)-1], m.Reader); !keepGoing {
+ if keepGoing := until(uint(nBytes-1), bytes[len(bytes)-1], m.GetReader()); !keepGoing {
return nil
}
nBytes++
}
}
-func (m *DefaultBufferedTransportInstance) PeekReadableBytes(numBytes uint32) ([]uint8, error) {
- if m.Reader == nil {
+func (m *defaultBufferedTransportInstance) PeekReadableBytes(numBytes uint32) ([]uint8, error) {
+ if m.GetReader() == nil {
return nil, errors.New("error peeking from transport. No reader available")
}
- return m.Peek(int(numBytes))
+ return m.GetReader().Peek(int(numBytes))
}
-func (m *DefaultBufferedTransportInstance) Read(numBytes uint32) ([]uint8, error) {
- if m.Reader == nil {
+func (m *defaultBufferedTransportInstance) Read(numBytes uint32) ([]uint8, error) {
+ if m.GetReader() == nil {
return nil, errors.New("error reading from transport. No reader available")
}
data := make([]uint8, numBytes)
for i := uint32(0); i < numBytes; i++ {
- val, err := m.ReadByte()
+ val, err := m.GetReader().ReadByte()
if err != nil {
return nil, errors.Wrap(err, "error reading")
}
diff --git a/plc4go/spi/transports/pcap/Transport.go b/plc4go/spi/transports/pcap/Transport.go
index 2c5181b84..bdeea2d7d 100644
--- a/plc4go/spi/transports/pcap/Transport.go
+++ b/plc4go/spi/transports/pcap/Transport.go
@@ -88,16 +88,19 @@ type TransportInstance struct {
transport *Transport
handle *pcap.Handle
mutex sync.Mutex
+ reader *bufio.Reader
}
func NewPcapTransportInstance(transportFile string, transportType TransportType, portRange string, speedFactor float32, transport *Transport) *TransportInstance {
- return &TransportInstance{
+ transportInstance := &TransportInstance{
transportFile: transportFile,
transportType: transportType,
portRange: portRange,
speedFactor: speedFactor,
transport: transport,
}
+ transportInstance.DefaultBufferedTransportInstance = transports.NewDefaultBufferedTransportInstance(transportInstance)
+ return transportInstance
}
func (m *TransportInstance) Connect() error {
@@ -118,7 +121,7 @@ func (m *TransportInstance) Connect() error {
m.handle = handle
m.connected = true
buffer := new(bytes.Buffer)
- m.Reader = bufio.NewReader(buffer)
+ m.reader = bufio.NewReader(buffer)
go func(m *TransportInstance, buffer *bytes.Buffer) {
packageCount := 0
@@ -186,3 +189,7 @@ func (m *TransportInstance) IsConnected() bool {
func (m *TransportInstance) Write(_ []uint8) error {
panic("Write to pcap not supported")
}
+
+func (m *TransportInstance) GetReader() *bufio.Reader {
+ return m.reader
+}
diff --git a/plc4go/spi/transports/serial/Transport.go b/plc4go/spi/transports/serial/Transport.go
index f7c046e57..1ccd68939 100644
--- a/plc4go/spi/transports/serial/Transport.go
+++ b/plc4go/spi/transports/serial/Transport.go
@@ -82,15 +82,18 @@ type TransportInstance struct {
ConnectTimeout uint32
transport *Transport
serialPort io.ReadWriteCloser
+ reader *bufio.Reader
}
func NewTransportInstance(serialPortName string, baudRate uint, connectTimeout uint32, transport *Transport) *TransportInstance {
- return &TransportInstance{
+ transportInstance := &TransportInstance{
SerialPortName: serialPortName,
BaudRate: baudRate,
ConnectTimeout: connectTimeout,
transport: transport,
}
+ transportInstance.DefaultBufferedTransportInstance = transports.NewDefaultBufferedTransportInstance(transportInstance)
+ return transportInstance
}
func (m *TransportInstance) Connect() error {
@@ -109,7 +112,7 @@ func (m *TransportInstance) Connect() error {
m.serialPort = utils.NewTransportLogger(m.serialPort, utils.WithLogger(fileLogger))
log.Trace().Msgf("Logging Transport to file %s", logFile.Name())
}*/
- m.Reader = bufio.NewReader(m.serialPort)
+ m.reader = bufio.NewReader(m.serialPort)
return nil
}
@@ -143,3 +146,7 @@ func (m *TransportInstance) Write(data []uint8) error {
}
return nil
}
+
+func (m *TransportInstance) GetReader() *bufio.Reader {
+ return m.reader
+}
diff --git a/plc4go/spi/transports/tcp/Transport.go b/plc4go/spi/transports/tcp/Transport.go
index c8b0e8773..caf5cf147 100644
--- a/plc4go/spi/transports/tcp/Transport.go
+++ b/plc4go/spi/transports/tcp/Transport.go
@@ -21,6 +21,7 @@ package tcp
import (
"bufio"
+ "context"
"fmt"
"github.com/apache/plc4x/plc4go/spi/transports"
"github.com/apache/plc4x/plc4go/spi/utils"
@@ -100,26 +101,34 @@ type TransportInstance struct {
ConnectTimeout uint32
transport *Transport
tcpConn net.Conn
+ reader *bufio.Reader
}
func NewTcpTransportInstance(remoteAddress *net.TCPAddr, connectTimeout uint32, transport *Transport) *TransportInstance {
- return &TransportInstance{
+ transportInstance := &TransportInstance{
RemoteAddress: remoteAddress,
ConnectTimeout: connectTimeout,
transport: transport,
}
+ transportInstance.DefaultBufferedTransportInstance = transports.NewDefaultBufferedTransportInstance(transportInstance)
+ return transportInstance
}
func (m *TransportInstance) Connect() error {
+ return m.ConnectWithContext(context.Background())
+}
+
+func (m *TransportInstance) ConnectWithContext(ctx context.Context) error {
var err error
- m.tcpConn, err = net.Dial("tcp", m.RemoteAddress.String())
+ var d net.Dialer
+ m.tcpConn, err = d.DialContext(ctx, "tcp", m.RemoteAddress.String())
if err != nil {
return errors.Wrap(err, "error connecting to remote address")
}
m.LocalAddress = m.tcpConn.LocalAddr().(*net.TCPAddr)
- m.Reader = bufio.NewReader(m.tcpConn)
+ m.reader = bufio.NewReader(m.tcpConn)
return nil
}
@@ -153,3 +162,7 @@ func (m *TransportInstance) Write(data []uint8) error {
}
return nil
}
+
+func (m *TransportInstance) GetReader() *bufio.Reader {
+ return m.reader
+}
diff --git a/plc4go/spi/transports/test/Transport.go b/plc4go/spi/transports/test/Transport.go
index 65b041752..ab326b819 100644
--- a/plc4go/spi/transports/test/Transport.go
+++ b/plc4go/spi/transports/test/Transport.go
@@ -22,6 +22,7 @@ package test
import (
"bufio"
"bytes"
+ "context"
"github.com/apache/plc4x/plc4go/spi/transports"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
@@ -70,6 +71,10 @@ func (m *TransportInstance) Connect() error {
return nil
}
+func (m *TransportInstance) ConnectWithContext(_ context.Context) error {
+ return m.Connect()
+}
+
func (m *TransportInstance) Close() error {
log.Trace().Msg("Close")
m.connected = false
diff --git a/plc4go/spi/transports/udp/Transport.go b/plc4go/spi/transports/udp/Transport.go
index 001fa381e..337f9f16a 100644
--- a/plc4go/spi/transports/udp/Transport.go
+++ b/plc4go/spi/transports/udp/Transport.go
@@ -21,6 +21,7 @@ package udp
import (
"bufio"
+ "context"
"github.com/apache/plc4x/plc4go/spi/transports"
"github.com/apache/plc4x/plc4go/spi/utils"
"github.com/libp2p/go-reuseport"
@@ -127,10 +128,15 @@ func NewTransportInstance(localAddress *net.UDPAddr, remoteAddress *net.UDPAddr,
}
func (m *TransportInstance) Connect() error {
+ return m.ConnectWithContext(context.Background())
+}
+
+func (m *TransportInstance) ConnectWithContext(ctx context.Context) error {
// If we haven't provided a local address, have the system figure it out by dialing
// the remote address and then using that connections local address as local address.
if m.LocalAddress == nil {
- udpTest, err := net.Dial("udp", m.RemoteAddress.String())
+ var d net.Dialer
+ udpTest, err := d.DialContext(ctx, "udp", m.RemoteAddress.String())
if err != nil {
return errors.Wrap(err, "error connecting to remote address")
}
diff --git a/plc4go/spi/utils/net.go b/plc4go/spi/utils/net.go
index f05a9dc4c..476d08bb4 100644
--- a/plc4go/spi/utils/net.go
+++ b/plc4go/spi/utils/net.go
@@ -23,6 +23,7 @@ import (
"bytes"
"context"
"net"
+ "sync"
"time"
"github.com/google/gopacket"
@@ -32,13 +33,14 @@ import (
"github.com/rs/zerolog/log"
)
-func GetIPAddresses(ctx context.Context, netInterface net.Interface, useArpBasedScan bool) (chan net.IP, error) {
- foundIps := make(chan net.IP, 65536)
+func GetIPAddresses(ctx context.Context, netInterface net.Interface, useArpBasedScan bool) (foundIps chan net.IP, err error) {
+ foundIps = make(chan net.IP, 65536)
addrs, err := netInterface.Addrs()
if err != nil {
return nil, errors.Wrap(err, "Error getting addresses")
}
go func() {
+ wg := &sync.WaitGroup{}
for _, address := range addrs {
// Check if context has been cancelled before continuing
select {
@@ -64,17 +66,20 @@ func GetIPAddresses(ctx context.Context, netInterface net.Interface, useArpBased
log.Debug().Stringer("IP", ipnet.IP).Stringer("Mask", ipnet.Mask).Msg("Expanding local subnet")
if useArpBasedScan {
- if err := lockupIpsUsingArp(ctx, netInterface, ipnet, foundIps); err != nil {
+ if err := lockupIpsUsingArp(ctx, netInterface, ipnet, foundIps, wg); err != nil {
log.Error().Err(err).Msg("failing to resolve using arp scan. Falling back to ip based scan")
useArpBasedScan = false
}
}
if !useArpBasedScan {
- if err := lookupIps(ctx, ipnet, foundIps); err != nil {
+ if err := lookupIps(ctx, ipnet, foundIps, wg); err != nil {
log.Error().Err(err).Msg("error looking up ips")
}
}
}
+ wg.Wait()
+ log.Trace().Msg("Closing found ips channel")
+ close(foundIps)
}()
return foundIps, nil
}
@@ -82,7 +87,10 @@ func GetIPAddresses(ctx context.Context, netInterface net.Interface, useArpBased
// As PING operations might be blocked by a firewall, responding to ARP packets is mandatory for IP based
// systems. So we are using an ARP scan to resolve the ethernet hardware addresses of each possible ip in range
// Only for devices that respond will we schedule a discovery.
-func lockupIpsUsingArp(ctx context.Context, netInterface net.Interface, ipNet *net.IPNet, foundIps chan net.IP) error {
+func lockupIpsUsingArp(ctx context.Context, netInterface net.Interface, ipNet *net.IPNet, foundIps chan net.IP, wg *sync.WaitGroup) error {
+ // We add on signal for error handling
+ wg.Add(1)
+ go func() { wg.Done() }()
log.Debug().Msgf("Scanning for alive IP addresses for interface '%s' and net: %s", netInterface.Name, ipNet)
// First find the pcap device name for the given interface.
allDevs, _ := pcap.FindAllDevs()
@@ -108,6 +116,8 @@ func lockupIpsUsingArp(ctx context.Context, netInterface net.Interface, ipNet *n
// Start up a goroutine to read in packet data.
stop := make(chan struct{})
+ // As we don't know how much the handler will find we use a value of 1 and set that to done after the 10 sec in the cleanup function directly after
+ wg.Add(1)
// Handler for processing incoming ARP responses.
go func(handle *pcap.Handle, iface net.Interface, stop chan struct{}) {
src := gopacket.NewPacketSource(handle, layers.LayerTypeEthernet)
@@ -146,6 +156,7 @@ func lockupIpsUsingArp(ctx context.Context, netInterface net.Interface, ipNet *n
// Make sure we clean up after 10 seconds.
defer func() {
go func() {
+ wg.Done()
time.Sleep(10 * time.Second)
handle.Close()
close(stop)
@@ -202,7 +213,7 @@ func lockupIpsUsingArp(ctx context.Context, netInterface net.Interface, ipNet *n
}
// Simply takes the IP address and the netmask and schedules one discovery task for every possible IP
-func lookupIps(ctx context.Context, ipnet *net.IPNet, foundIps chan net.IP) error {
+func lookupIps(ctx context.Context, ipnet *net.IPNet, foundIps chan net.IP, wg *sync.WaitGroup) error {
log.Debug().Msgf("Scanning all IP addresses for network: %s", ipnet)
// expand CIDR-block into one target for each IP
// Remark: The last IP address a network contains is a special broadcast address. We don't want to check that one.
@@ -214,7 +225,9 @@ func lookupIps(ctx context.Context, ipnet *net.IPNet, foundIps chan net.IP) erro
default:
}
+ wg.Add(1)
go func(ip net.IP) {
+ defer func() { wg.Done() }()
select {
case <-ctx.Done():
case foundIps <- ip: