You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@plc4x.apache.org by ld...@apache.org on 2024/02/12 14:34:44 UTC

(plc4x) 01/08: fix(plc4go): options should now correctly be applied

This is an automated email from the ASF dual-hosted git repository.

ldywicki pushed a commit to branch pg/security-policy
in repository https://gitbox.apache.org/repos/asf/plc4x.git

commit 9f658e6ded38ab2fa4d621dbca56297e0f86b2d4
Author: Sebastian Rühl <sr...@apache.org>
AuthorDate: Mon Feb 12 11:40:50 2024 +0100

    fix(plc4go): options should now correctly be applied
---
 plc4go/internal/cbus/Configuration.go           | 10 +++++
 plc4go/internal/opcua/Configuration.go          | 54 +++++++++++++++----------
 plc4go/internal/opcua/Configuration_plc4xgen.go | 38 ++++++++---------
 plc4go/internal/opcua/Driver.go                 | 12 +++---
 plc4go/internal/opcua/DriverContext.go          |  2 +-
 plc4go/internal/opcua/SecureChannel.go          | 44 ++++++++++----------
 6 files changed, 90 insertions(+), 70 deletions(-)

diff --git a/plc4go/internal/cbus/Configuration.go b/plc4go/internal/cbus/Configuration.go
index 2bfdad9a3e..f3a787dbd0 100644
--- a/plc4go/internal/cbus/Configuration.go
+++ b/plc4go/internal/cbus/Configuration.go
@@ -21,6 +21,8 @@ package cbus
 
 import (
 	"github.com/rs/zerolog"
+	"golang.org/x/text/cases"
+	"golang.org/x/text/language"
 	"reflect"
 	"strconv"
 
@@ -45,6 +47,7 @@ type Configuration struct {
 }
 
 func ParseFromOptions(log zerolog.Logger, options map[string][]string) (Configuration, error) {
+	titleOptions(options)
 	configuration := createDefaultConfiguration()
 	reflectConfiguration := reflect.ValueOf(&configuration).Elem()
 	for i := 0; i < reflectConfiguration.NumField(); i++ {
@@ -72,6 +75,13 @@ func ParseFromOptions(log zerolog.Logger, options map[string][]string) (Configur
 	return configuration, nil
 }
 
+func titleOptions(options map[string][]string) {
+	caser := cases.Title(language.AmericanEnglish, cases.NoLower) // upper-case only the first letter so keys match field names; NoLower keeps "camelCase" -> "CamelCase" instead of "Camelcase"
+	for key, value := range options {
+		options[caser.String(key)] = value // re-inserting while ranging is safe: the caser is idempotent on its own output
+	}
+}
+
 func createDefaultConfiguration() Configuration {
 	return Configuration{
 		Exstat:   true,
diff --git a/plc4go/internal/opcua/Configuration.go b/plc4go/internal/opcua/Configuration.go
index 21dcecea76..e6684b29fc 100644
--- a/plc4go/internal/opcua/Configuration.go
+++ b/plc4go/internal/opcua/Configuration.go
@@ -29,32 +29,35 @@ import (
 
 	"github.com/pkg/errors"
 	"github.com/rs/zerolog"
+	"golang.org/x/text/cases"
+	"golang.org/x/text/language"
 )
 
 //go:generate go run ../../tools/plc4xgenerator/gen.go -type=Configuration
 type Configuration struct {
-	code              string
-	host              string
-	port              string
-	endpoint          string
-	transportEndpoint string
-	params            string
-	isEncrypted       bool
-	thumbprint        readWriteModel.PascalByteString
-	senderCertificate []byte
-	discovery         bool
-	username          string
-	password          string
-	securityPolicy    string
-	keyStoreFile      string
-	certDirectory     string
-	keyStorePassword  string
-	ckp               *CertificateKeyPair
+	Code              string
+	Host              string
+	Port              string
+	Endpoint          string
+	TransportEndpoint string
+	Params            string
+	IsEncrypted       bool
+	Thumbprint        readWriteModel.PascalByteString
+	SenderCertificate []byte
+	Discovery         bool
+	Username          string
+	Password          string
+	SecurityPolicy    string
+	KeyStoreFile      string
+	CertDirectory     string
+	KeyStorePassword  string
+	Ckp               *CertificateKeyPair
 
 	log zerolog.Logger `ignore:"true"`
 }
 
 func ParseFromOptions(log zerolog.Logger, options map[string][]string) (Configuration, error) {
+	titleOptions(options)
 	configuration := createDefaultConfiguration()
 	reflectConfiguration := reflect.ValueOf(&configuration).Elem()
 	for i := 0; i < reflectConfiguration.NumField(); i++ {
@@ -83,19 +86,26 @@ func ParseFromOptions(log zerolog.Logger, options map[string][]string) (Configur
 	return configuration, nil
 }
 
+func titleOptions(options map[string][]string) {
+	caser := cases.Title(language.AmericanEnglish, cases.NoLower) // NoLower keeps inner capitals: "keyStoreFile" -> "KeyStoreFile" (default Title yields "Keystorefile", which matches no field)
+	for key, value := range options {
+		options[caser.String(key)] = value // re-inserting while ranging is safe: the caser is idempotent on its own output
+	}
+}
+
 func (c *Configuration) openKeyStore() error {
-	c.isEncrypted = true
-	securityTempDir := path.Join(c.certDirectory, "security")
+	c.IsEncrypted = true
+	securityTempDir := path.Join(c.CertDirectory, "security")
 	if _, err := os.Stat(securityTempDir); errors.Is(err, os.ErrNotExist) {
 		if err := os.Mkdir(securityTempDir, 700); err != nil {
 			return errors.New("Unable to create directory please confirm folder permissions on " + securityTempDir)
 		}
 	}
 
-	serverKeyStore := path.Join(securityTempDir, c.keyStoreFile)
+	serverKeyStore := path.Join(securityTempDir, c.KeyStoreFile)
 	if _, err := os.Stat(securityTempDir); errors.Is(err, os.ErrNotExist) {
 		var err error
-		c.ckp, err = generateCertificate()
+		c.Ckp, err = generateCertificate()
 		if err != nil {
 			return errors.Wrap(err, "error generating certificate")
 		}
@@ -117,7 +127,7 @@ func (c *Configuration) openKeyStore() error {
 
 func createDefaultConfiguration() Configuration {
 	return Configuration{
-		securityPolicy: "None",
+		SecurityPolicy: "None",
 	}
 }
 
diff --git a/plc4go/internal/opcua/Configuration_plc4xgen.go b/plc4go/internal/opcua/Configuration_plc4xgen.go
index 3d7372ec9a..359d1c1b49 100644
--- a/plc4go/internal/opcua/Configuration_plc4xgen.go
+++ b/plc4go/internal/opcua/Configuration_plc4xgen.go
@@ -43,36 +43,36 @@ func (d *Configuration) SerializeWithWriteBuffer(ctx context.Context, writeBuffe
 		return err
 	}
 
-	if err := writeBuffer.WriteString("code", uint32(len(d.code)*8), "UTF-8", d.code); err != nil {
+	if err := writeBuffer.WriteString("code", uint32(len(d.Code)*8), "UTF-8", d.Code); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteString("host", uint32(len(d.host)*8), "UTF-8", d.host); err != nil {
+	if err := writeBuffer.WriteString("host", uint32(len(d.Host)*8), "UTF-8", d.Host); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteString("port", uint32(len(d.port)*8), "UTF-8", d.port); err != nil {
+	if err := writeBuffer.WriteString("port", uint32(len(d.Port)*8), "UTF-8", d.Port); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteString("endpoint", uint32(len(d.endpoint)*8), "UTF-8", d.endpoint); err != nil {
+	if err := writeBuffer.WriteString("endpoint", uint32(len(d.Endpoint)*8), "UTF-8", d.Endpoint); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteString("transportEndpoint", uint32(len(d.transportEndpoint)*8), "UTF-8", d.transportEndpoint); err != nil {
+	if err := writeBuffer.WriteString("transportEndpoint", uint32(len(d.TransportEndpoint)*8), "UTF-8", d.TransportEndpoint); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteString("params", uint32(len(d.params)*8), "UTF-8", d.params); err != nil {
+	if err := writeBuffer.WriteString("params", uint32(len(d.Params)*8), "UTF-8", d.Params); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteBit("isEncrypted", d.isEncrypted); err != nil {
+	if err := writeBuffer.WriteBit("isEncrypted", d.IsEncrypted); err != nil {
 		return err
 	}
 
-	if d.thumbprint != nil {
-		if serializableField, ok := d.thumbprint.(utils.Serializable); ok {
+	if d.Thumbprint != nil {
+		if serializableField, ok := d.Thumbprint.(utils.Serializable); ok {
 			if err := writeBuffer.PushContext("thumbprint"); err != nil {
 				return err
 			}
@@ -83,45 +83,45 @@ func (d *Configuration) SerializeWithWriteBuffer(ctx context.Context, writeBuffe
 				return err
 			}
 		} else {
-			stringValue := fmt.Sprintf("%v", d.thumbprint)
+			stringValue := fmt.Sprintf("%v", d.Thumbprint)
 			if err := writeBuffer.WriteString("thumbprint", uint32(len(stringValue)*8), "UTF-8", stringValue); err != nil {
 				return err
 			}
 		}
 	}
-	if err := writeBuffer.WriteByteArray("senderCertificate", d.senderCertificate); err != nil {
+	if err := writeBuffer.WriteByteArray("senderCertificate", d.SenderCertificate); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteBit("discovery", d.discovery); err != nil {
+	if err := writeBuffer.WriteBit("discovery", d.Discovery); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteString("username", uint32(len(d.username)*8), "UTF-8", d.username); err != nil {
+	if err := writeBuffer.WriteString("username", uint32(len(d.Username)*8), "UTF-8", d.Username); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteString("password", uint32(len(d.password)*8), "UTF-8", d.password); err != nil {
+	if err := writeBuffer.WriteString("password", uint32(len(d.Password)*8), "UTF-8", d.Password); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteString("securityPolicy", uint32(len(d.securityPolicy)*8), "UTF-8", d.securityPolicy); err != nil {
+	if err := writeBuffer.WriteString("securityPolicy", uint32(len(d.SecurityPolicy)*8), "UTF-8", d.SecurityPolicy); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteString("keyStoreFile", uint32(len(d.keyStoreFile)*8), "UTF-8", d.keyStoreFile); err != nil {
+	if err := writeBuffer.WriteString("keyStoreFile", uint32(len(d.KeyStoreFile)*8), "UTF-8", d.KeyStoreFile); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteString("certDirectory", uint32(len(d.certDirectory)*8), "UTF-8", d.certDirectory); err != nil {
+	if err := writeBuffer.WriteString("certDirectory", uint32(len(d.CertDirectory)*8), "UTF-8", d.CertDirectory); err != nil {
 		return err
 	}
 
-	if err := writeBuffer.WriteString("keyStorePassword", uint32(len(d.keyStorePassword)*8), "UTF-8", d.keyStorePassword); err != nil {
+	if err := writeBuffer.WriteString("keyStorePassword", uint32(len(d.KeyStorePassword)*8), "UTF-8", d.KeyStorePassword); err != nil {
 		return err
 	}
 	{
-		_value := fmt.Sprintf("%v", d.ckp)
+		_value := fmt.Sprintf("%v", d.Ckp)
 
 		if err := writeBuffer.WriteString("ckp", uint32(len(_value)*8), "UTF-8", _value); err != nil {
 			return err
diff --git a/plc4go/internal/opcua/Driver.go b/plc4go/internal/opcua/Driver.go
index 76ba9f12aa..b017fd39ed 100644
--- a/plc4go/internal/opcua/Driver.go
+++ b/plc4go/internal/opcua/Driver.go
@@ -125,17 +125,17 @@ func (d *Driver) GetConnectionWithContext(ctx context.Context, transportUrl url.
 	if err != nil {
 		return d.reportError(errors.Wrap(err, "can't parse options"))
 	}
-	configuration.host = transportHost
-	configuration.port = transportPort
-	configuration.transportEndpoint = transportEndpoint
+	configuration.Host = transportHost
+	configuration.Port = transportPort
+	configuration.TransportEndpoint = transportEndpoint
 	portAddition := ""
 	if transportPort != "" {
 		portAddition += ":" + transportPort
 	}
-	configuration.endpoint = "opc." + transportCode + "://" + transportHost + portAddition + "" + transportEndpoint
-	d.log.Debug().Stringer("configuration", &configuration).Msg("working with configurartion")
+	configuration.Endpoint = "opc." + transportCode + "://" + transportHost + portAddition + "" + transportEndpoint
+	d.log.Debug().Stringer("configuration", &configuration).Msg("working with configuration")
 
-	if securityPolicy := configuration.securityPolicy; securityPolicy != "" && securityPolicy != "None" {
+	if securityPolicy := configuration.SecurityPolicy; securityPolicy != "" && securityPolicy != "None" {
 		d.log.Trace().Str("securityPolicy", securityPolicy).Msg("working with security policy")
 		if err := configuration.openKeyStore(); err != nil {
 			return d.reportError(errors.Wrap(err, "error opening key store"))
diff --git a/plc4go/internal/opcua/DriverContext.go b/plc4go/internal/opcua/DriverContext.go
index bf0fcc2c46..284aa9f01e 100644
--- a/plc4go/internal/opcua/DriverContext.go
+++ b/plc4go/internal/opcua/DriverContext.go
@@ -29,7 +29,7 @@ type DriverContext struct {
 
 func NewDriverContext(configuration Configuration) DriverContext {
 	return DriverContext{
-		fireDiscoverEvent:            configuration.isEncrypted,
+		fireDiscoverEvent:            configuration.IsEncrypted,
 		awaitSetupComplete:           true,
 		awaitDisconnectComplete:      true,
 		awaitSessionDiscoverComplete: true,
diff --git a/plc4go/internal/opcua/SecureChannel.go b/plc4go/internal/opcua/SecureChannel.go
index 8b713d1b5e..eed6a455c0 100644
--- a/plc4go/internal/opcua/SecureChannel.go
+++ b/plc4go/internal/opcua/SecureChannel.go
@@ -128,14 +128,14 @@ type SecureChannel struct {
 func NewSecureChannel(log zerolog.Logger, ctx DriverContext, configuration Configuration) *SecureChannel {
 	s := &SecureChannel{
 		configuration:             configuration,
-		endpoint:                  readWriteModel.NewPascalString(configuration.endpoint),
-		username:                  configuration.username,
-		password:                  configuration.password,
-		securityPolicy:            "http://opcfoundation.org/UA/SecurityPolicy#" + configuration.securityPolicy,
+		endpoint:                  readWriteModel.NewPascalString(configuration.Endpoint),
+		username:                  configuration.Username,
+		password:                  configuration.Password,
+		securityPolicy:            "http://opcfoundation.org/UA/SecurityPolicy#" + configuration.SecurityPolicy,
 		sessionName:               "UaSession:" + APPLICATION_TEXT.GetStringValue() + ":" + uniuri.NewLen(20),
 		authenticationToken:       readWriteModel.NewNodeIdTwoByte(0),
 		clientNonce:               []byte(uniuri.NewLen(40)),
-		keyStoreFile:              configuration.keyStoreFile,
+		keyStoreFile:              configuration.KeyStoreFile,
 		channelTransactionManager: NewSecureChannelTransactionManager(log),
 		lifetime:                  DEFAULT_CONNECTION_LIFETIME,
 		log:                       log,
@@ -143,18 +143,18 @@ func NewSecureChannel(log zerolog.Logger, ctx DriverContext, configuration Confi
 	s.requestHandleGenerator.Store(1)
 	s.channelId.Store(1)
 	s.tokenId.Store(1)
-	ckp := configuration.ckp
-	if configuration.securityPolicy == "Basic256Sha256" {
+	ckp := configuration.Ckp
+	if configuration.SecurityPolicy == "Basic256Sha256" {
 		//Sender Certificate gets populated during the 'discover' phase when encryption is enabled.
-		s.senderCertificate = configuration.senderCertificate
-		s.encryptionHandler = NewEncryptionHandler(s.log, ckp, s.senderCertificate, configuration.securityPolicy)
+		s.senderCertificate = configuration.SenderCertificate
+		s.encryptionHandler = NewEncryptionHandler(s.log, ckp, s.senderCertificate, configuration.SecurityPolicy)
 		certificate := ckp.getCertificate()
 		s.publicCertificate = readWriteModel.NewPascalByteString(int32(len(certificate.Raw)), certificate.Raw)
 		s.isEncrypted = true
 
-		s.thumbprint = configuration.thumbprint
+		s.thumbprint = configuration.Thumbprint
 	} else {
-		s.encryptionHandler = NewEncryptionHandler(s.log, ckp, s.senderCertificate, configuration.securityPolicy)
+		s.encryptionHandler = NewEncryptionHandler(s.log, ckp, s.senderCertificate, configuration.SecurityPolicy)
 		s.publicCertificate = NULL_BYTE_STRING
 		s.thumbprint = NULL_BYTE_STRING
 		s.isEncrypted = false
@@ -163,7 +163,7 @@ func NewSecureChannel(log zerolog.Logger, ctx DriverContext, configuration Confi
 	// Generate a list of endpoints we can use.
 	{
 		var err error
-		address, err := url.Parse("none://" + configuration.host)
+		address, err := url.Parse("none://" + configuration.Host)
 		if err == nil {
 			if names, lookupErr := net.LookupHost(address.Host); lookupErr == nil {
 				s.endpoints = append(s.endpoints, names[rand.Intn(len(names))])
@@ -611,11 +611,11 @@ func (s *SecureChannel) onConnectActivateSessionRequest(ctx context.Context, con
 	s.encryptionHandler.setServerCertificate(certificate)
 	s.senderNonce = sessionResponse.GetServerNonce().GetStringValue()
 	endpoints := make([]string, 3)
-	if address, err := url.Parse(s.configuration.host); err != nil {
+	if address, err := url.Parse(s.configuration.Host); err != nil {
 		if names, err := net.LookupAddr(address.Host); err != nil {
-			endpoints[0] = "opc.tcp://" + names[rand.Intn(len(names))] + ":" + s.configuration.port + s.configuration.transportEndpoint
+			endpoints[0] = "opc.tcp://" + names[rand.Intn(len(names))] + ":" + s.configuration.Port + s.configuration.TransportEndpoint
 		}
-		endpoints[1] = "opc.tcp://" + address.Hostname() + ":" + s.configuration.port + s.configuration.transportEndpoint
+		endpoints[1] = "opc.tcp://" + address.Hostname() + ":" + s.configuration.Port + s.configuration.TransportEndpoint
 		//endpoints[2] = "opc.tcp://" + address.getCanonicalHostName() + ":" + s.configuration.getPort() + s.configuration.transportEndpoint// TODO: not sure how to get that in golang
 	}
 
@@ -1190,11 +1190,11 @@ func (s *SecureChannel) onDiscoverGetEndpointsRequest(ctx context.Context, codec
 						endpointDescription := endpoint.(readWriteModel.EndpointDescription)
 						if endpointDescription.GetEndpointUrl().GetStringValue() == (s.endpoint.GetStringValue()) && endpointDescription.GetSecurityPolicyUri().GetStringValue() == (s.securityPolicy) {
 							s.log.Info().Str("stringValue", s.endpoint.GetStringValue()).Msg("Found OPC UA endpoint")
-							s.configuration.senderCertificate = endpointDescription.GetServerCertificate().GetStringValue()
+							s.configuration.SenderCertificate = endpointDescription.GetServerCertificate().GetStringValue()
 						}
 					}
 
-					digest := sha1.Sum(s.configuration.senderCertificate)
+					digest := sha1.Sum(s.configuration.SenderCertificate)
 					s.thumbprint = readWriteModel.NewPascalByteString(int32(len(digest)), digest[:])
 
 					go s.onDiscoverCloseSecureChannel(ctx, codec, response)
@@ -1538,20 +1538,20 @@ func (s *SecureChannel) isEndpoint(endpoint readWriteModel.EndpointDescription)
 		Str("transportEndpoint", matches["transportEndpoint"]).
 		Msg("Using Endpoint")
 
-	if s.configuration.discovery && !slices.Contains(s.endpoints, matches["transportHost"]) {
+	if s.configuration.Discovery && !slices.Contains(s.endpoints, matches["transportHost"]) {
 		return false
 	}
 
-	if s.configuration.port != matches["transportPort"] {
+	if s.configuration.Port != matches["transportPort"] {
 		return false
 	}
 
-	if s.configuration.transportEndpoint != matches["transportEndpoint"] {
+	if s.configuration.TransportEndpoint != matches["transportEndpoint"] {
 		return false
 	}
 
-	if !s.configuration.discovery {
-		s.configuration.host = matches["transportHost"]
+	if !s.configuration.Discovery {
+		s.configuration.Host = matches["transportHost"]
 	}
 
 	return true