You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this plain-text copy.
Posted to commits@plc4x.apache.org by sr...@apache.org on 2023/06/19 13:51:00 UTC

[plc4x] 05/05: fix(plc4go/spi): fix race issues in request transaction

This is an automated email from the ASF dual-hosted git repository.

sruehl pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/plc4x.git

commit 0458529ee1acfda55ae7ffaf75d6c95045a06fea
Author: Sebastian Rühl <sr...@apache.org>
AuthorDate: Mon Jun 19 15:49:48 2023 +0200

    fix(plc4go/spi): fix race issues in request transaction
---
 plc4go/spi/transactions/RequestTransaction.go      |  25 +++-
 .../spi/transactions/RequestTransactionManager.go  |   6 +-
 .../transactions/RequestTransactionManager_test.go |  26 ++--
 plc4go/spi/transactions/RequestTransaction_test.go | 154 +++++++++++----------
 .../transactions/requestTransaction_plc4xgen.go    |   7 +-
 5 files changed, 121 insertions(+), 97 deletions(-)

diff --git a/plc4go/spi/transactions/RequestTransaction.go b/plc4go/spi/transactions/RequestTransaction.go
index 18f88c7293..760a0d6a4c 100644
--- a/plc4go/spi/transactions/RequestTransaction.go
+++ b/plc4go/spi/transactions/RequestTransaction.go
@@ -22,12 +22,15 @@ package transactions
 import (
 	"context"
 	"fmt"
+	"sync"
+	"sync/atomic"
+	"time"
+
 	"github.com/apache/plc4x/plc4go/spi/pool"
+
 	"github.com/pkg/errors"
 	"github.com/rs/zerolog"
 	"github.com/rs/zerolog/log"
-	"sync"
-	"time"
 )
 
 // RequestTransaction represents a transaction
@@ -58,7 +61,7 @@ type requestTransaction struct {
 
 	/** The initial operation to perform to kick off the request */
 	operation        pool.Runnable `ignore:"true"` // TODO: maybe we can treat this as a function some day if we are able to check the definition in gen
-	completionFuture pool.CompletionFuture
+	completionFuture atomic.Pointer[pool.CompletionFuture]
 
 	stateChangeMutex sync.Mutex
 	completed        bool
@@ -66,6 +69,18 @@ type requestTransaction struct {
 	transactionLog zerolog.Logger `ignore:"true"`
 }
 
+func (t *requestTransaction) setCompletionFuture(completionFuture pool.CompletionFuture) {
+	t.completionFuture.Store(&completionFuture)
+}
+
+func (t *requestTransaction) getCompletionFuture() pool.CompletionFuture {
+	completionFutureLoaded := t.completionFuture.Load()
+	if completionFutureLoaded == nil {
+		return nil
+	}
+	return *completionFutureLoaded
+}
+
 //
 // Internal section
 //
@@ -118,14 +133,14 @@ func (t *requestTransaction) AwaitCompletion(ctx context.Context) error {
 	t.transactionLog.Trace().Int32("transactionId", t.transactionId).Msg("Awaiting completion")
 	timeout, cancelFunc := context.WithTimeout(ctx, time.Minute*30) // This is intentionally set very high
 	defer cancelFunc()
-	for t.completionFuture == nil {
+	for t.getCompletionFuture() == nil {
 		time.Sleep(time.Millisecond * 10)
 		if err := timeout.Err(); err != nil {
 			log.Error().Msg("Timout after a long time. This means something is very of here")
 			return errors.Wrap(err, "Error waiting for completion future to be set")
 		}
 	}
-	if err := t.completionFuture.AwaitCompletion(ctx); err != nil {
+	if err := t.getCompletionFuture().AwaitCompletion(ctx); err != nil {
 		t.transactionLog.Trace().Int32("transactionId", t.transactionId).Msg("Errored")
 		return err
 	}
diff --git a/plc4go/spi/transactions/RequestTransactionManager.go b/plc4go/spi/transactions/RequestTransactionManager.go
index 83551e0805..9fc80d6f4d 100644
--- a/plc4go/spi/transactions/RequestTransactionManager.go
+++ b/plc4go/spi/transactions/RequestTransactionManager.go
@@ -169,7 +169,7 @@ func (r *requestTransactionManager) processWorklog() {
 		r.log.Debug().Msgf("Handling next\n%v\n. (Adding to running requests (length: %d))", next, len(r.runningRequests))
 		r.runningRequests = append(r.runningRequests, next)
 		completionFuture := r.executor.Submit(context.Background(), next.transactionId, next.operation)
-		next.completionFuture = completionFuture
+		next.setCompletionFuture(completionFuture)
 		r.workLog.Remove(front)
 	}
 }
@@ -190,7 +190,7 @@ func (r *requestTransactionManager) StartTransaction() RequestTransaction {
 	}
 	if r.shutdown.Load() {
 		transaction.completed = true
-		transaction.completionFuture = &completedFuture{errors.New("request transaction manager in shutdown")}
+		transaction.setCompletionFuture(&completedFuture{errors.New("request transaction manager in shutdown")})
 	}
 	return transaction
 }
@@ -203,7 +203,7 @@ func (r *requestTransactionManager) getNumberOfActiveRequests() int {
 
 func (r *requestTransactionManager) failRequest(transaction *requestTransaction, err error) error {
 	// Try to fail it!
-	transaction.completionFuture.Cancel(true, err)
+	transaction.getCompletionFuture().Cancel(true, err)
 	// End it
 	return r.endRequest(transaction)
 }
diff --git a/plc4go/spi/transactions/RequestTransactionManager_test.go b/plc4go/spi/transactions/RequestTransactionManager_test.go
index 63041083ea..b80299c202 100644
--- a/plc4go/spi/transactions/RequestTransactionManager_test.go
+++ b/plc4go/spi/transactions/RequestTransactionManager_test.go
@@ -280,10 +280,11 @@ func Test_requestTransactionManager_failRequest(t *testing.T) {
 				transaction: &requestTransaction{},
 			},
 			mockSetup: func(t *testing.T, fields *fields, args *args) {
-				completionFuture := NewMockCompletionFuture(t)
-				expect := completionFuture.EXPECT()
+				completionFutureMock := NewMockCompletionFuture(t)
+				expect := completionFutureMock.EXPECT()
 				expect.Cancel(true, nil).Return()
-				args.transaction.completionFuture = completionFuture
+				var completionFuture pool.CompletionFuture = completionFutureMock
+				args.transaction.completionFuture.Store(&completionFuture)
 			},
 			wantErr: true,
 		},
@@ -374,14 +375,17 @@ func Test_requestTransactionManager_processWorklog(t *testing.T) {
 				numberOfConcurrentRequests: 100,
 				workLog: func() list.List {
 					l := list.New()
-					l.PushBack(&requestTransaction{
-						transactionId:    1,
-						completionFuture: NewMockCompletionFuture(t),
-					})
-					l.PushBack(&requestTransaction{
-						transactionId:    2,
-						completionFuture: NewMockCompletionFuture(t),
-					})
+					var completionFuture pool.CompletionFuture = NewMockCompletionFuture(t)
+					r1 := &requestTransaction{
+						transactionId: 1,
+					}
+					r1.completionFuture.Store(&completionFuture)
+					l.PushBack(r1)
+					r2 := &requestTransaction{
+						transactionId: 2,
+					}
+					r2.completionFuture.Store(&completionFuture)
+					l.PushBack(r2)
 					return *l
 				}(),
 				executor: sharedExecutorInstance,
diff --git a/plc4go/spi/transactions/RequestTransaction_test.go b/plc4go/spi/transactions/RequestTransaction_test.go
index 4288d703ba..00c68fe1f9 100644
--- a/plc4go/spi/transactions/RequestTransaction_test.go
+++ b/plc4go/spi/transactions/RequestTransaction_test.go
@@ -35,11 +35,10 @@ import (
 
 func Test_requestTransaction_EndRequest(t1 *testing.T) {
 	type fields struct {
-		parent           *requestTransactionManager
-		transactionId    int32
-		operation        pool.Runnable
-		completionFuture pool.CompletionFuture
-		completed        bool
+		parent        *requestTransactionManager
+		transactionId int32
+		operation     pool.Runnable
+		completed     bool
 	}
 	tests := []struct {
 		name    string
@@ -65,12 +64,11 @@ func Test_requestTransaction_EndRequest(t1 *testing.T) {
 	for _, tt := range tests {
 		t1.Run(tt.name, func(t1 *testing.T) {
 			t := &requestTransaction{
-				parent:           tt.fields.parent,
-				transactionId:    tt.fields.transactionId,
-				operation:        tt.fields.operation,
-				completionFuture: tt.fields.completionFuture,
-				transactionLog:   testutils.ProduceTestingLogger(t1),
-				completed:        tt.fields.completed,
+				parent:         tt.fields.parent,
+				transactionId:  tt.fields.transactionId,
+				operation:      tt.fields.operation,
+				transactionLog: testutils.ProduceTestingLogger(t1),
+				completed:      tt.fields.completed,
 			}
 			if err := t.EndRequest(); (err != nil) != tt.wantErr {
 				t1.Errorf("EndRequest() error = %v, wantErr %v", err, tt.wantErr)
@@ -81,33 +79,34 @@ func Test_requestTransaction_EndRequest(t1 *testing.T) {
 
 func Test_requestTransaction_FailRequest(t1 *testing.T) {
 	type fields struct {
-		parent           *requestTransactionManager
-		transactionId    int32
-		operation        pool.Runnable
-		completionFuture pool.CompletionFuture
-		transactionLog   zerolog.Logger
-		completed        bool
+		parent         *requestTransactionManager
+		transactionId  int32
+		operation      pool.Runnable
+		transactionLog zerolog.Logger
+		completed      bool
 	}
 	type args struct {
 		err error
 	}
 	tests := []struct {
-		name      string
-		fields    fields
-		args      args
-		mockSetup func(t *testing.T, fields *fields, args *args)
-		wantErr   assert.ErrorAssertionFunc
+		name        string
+		fields      fields
+		args        args
+		mockSetup   func(t *testing.T, fields *fields, args *args)
+		manipulator func(t *testing.T, transaction *requestTransaction)
+		wantErr     assert.ErrorAssertionFunc
 	}{
 		{
 			name: "just fail it",
 			fields: fields{
 				parent: &requestTransactionManager{},
 			},
-			mockSetup: func(t *testing.T, fields *fields, args *args) {
-				completionFuture := NewMockCompletionFuture(t)
-				expect := completionFuture.EXPECT()
+			manipulator: func(t *testing.T, transaction *requestTransaction) {
+				completionFutureMock := NewMockCompletionFuture(t)
+				expect := completionFutureMock.EXPECT()
 				expect.Cancel(true, nil).Return()
-				fields.completionFuture = completionFuture
+				var completionFuture pool.CompletionFuture = completionFutureMock
+				transaction.completionFuture.Store(&completionFuture)
 			},
 			wantErr: assert.Error,
 		},
@@ -129,32 +128,37 @@ func Test_requestTransaction_FailRequest(t1 *testing.T) {
 				tt.mockSetup(t, &tt.fields, &tt.args)
 			}
 			r := &requestTransaction{
-				parent:           tt.fields.parent,
-				transactionId:    tt.fields.transactionId,
-				operation:        tt.fields.operation,
-				completionFuture: tt.fields.completionFuture,
-				transactionLog:   tt.fields.transactionLog,
-				completed:        tt.fields.completed,
+				parent:         tt.fields.parent,
+				transactionId:  tt.fields.transactionId,
+				operation:      tt.fields.operation,
+				transactionLog: tt.fields.transactionLog,
+				completed:      tt.fields.completed,
+			}
+			if tt.manipulator != nil {
+				tt.manipulator(t, r)
 			}
 			tt.wantErr(t, r.FailRequest(tt.args.err), "FailRequest() error = %v", tt.args.err)
 		})
 	}
 }
 
-func Test_requestTransaction_String(t1 *testing.T) {
+func Test_requestTransaction_String(t *testing.T) {
 	type fields struct {
-		parent           *requestTransactionManager
-		transactionId    int32
-		operation        pool.Runnable
-		completionFuture pool.CompletionFuture
+		parent        *requestTransactionManager
+		transactionId int32
+		operation     pool.Runnable
 	}
 	tests := []struct {
-		name   string
-		fields fields
-		want   string
+		name        string
+		fields      fields
+		manipulator func(t *testing.T, transaction *requestTransaction)
+		want        string
 	}{
 		{
 			name: "give a string",
+			manipulator: func(t *testing.T, transaction *requestTransaction) {
+				transaction.setCompletionFuture(nil)
+			},
 			want: `
 ╔═requestTransaction═════════╗
 ║╔═transactionId╗╔═completed╗║
@@ -164,15 +168,17 @@ func Test_requestTransaction_String(t1 *testing.T) {
 		},
 	}
 	for _, tt := range tests {
-		t1.Run(tt.name, func(t1 *testing.T) {
-			t := &requestTransaction{
-				parent:           tt.fields.parent,
-				transactionId:    tt.fields.transactionId,
-				operation:        tt.fields.operation,
-				completionFuture: tt.fields.completionFuture,
-				transactionLog:   testutils.ProduceTestingLogger(t1),
+		t.Run(tt.name, func(t1 *testing.T) {
+			_t := &requestTransaction{
+				parent:         tt.fields.parent,
+				transactionId:  tt.fields.transactionId,
+				operation:      tt.fields.operation,
+				transactionLog: testutils.ProduceTestingLogger(t1),
 			}
-			if got := t.String(); got != tt.want {
+			if tt.manipulator != nil {
+				tt.manipulator(t, _t)
+			}
+			if got := _t.String(); got != tt.want {
 				t1.Errorf("String() = \n%v, want \n%v", got, tt.want)
 			}
 		})
@@ -181,12 +187,11 @@ func Test_requestTransaction_String(t1 *testing.T) {
 
 func Test_requestTransaction_Submit(t1 *testing.T) {
 	type fields struct {
-		parent           *requestTransactionManager
-		transactionId    int32
-		operation        pool.Runnable
-		completionFuture pool.CompletionFuture
-		transactionLog   zerolog.Logger
-		completed        bool
+		parent         *requestTransactionManager
+		transactionId  int32
+		operation      pool.Runnable
+		transactionLog zerolog.Logger
+		completed      bool
 	}
 	type args struct {
 		operation RequestTransactionRunnable
@@ -240,12 +245,11 @@ func Test_requestTransaction_Submit(t1 *testing.T) {
 	for _, tt := range tests {
 		t1.Run(tt.name, func(t1 *testing.T) {
 			t := &requestTransaction{
-				parent:           tt.fields.parent,
-				transactionId:    tt.fields.transactionId,
-				operation:        tt.fields.operation,
-				completionFuture: tt.fields.completionFuture,
-				transactionLog:   tt.fields.transactionLog,
-				completed:        tt.fields.completed,
+				parent:         tt.fields.parent,
+				transactionId:  tt.fields.transactionId,
+				operation:      tt.fields.operation,
+				transactionLog: tt.fields.transactionLog,
+				completed:      tt.fields.completed,
 			}
 			t.Submit(tt.args.operation)
 			t.operation()
@@ -255,10 +259,9 @@ func Test_requestTransaction_Submit(t1 *testing.T) {
 
 func Test_requestTransaction_AwaitCompletion(t1 *testing.T) {
 	type fields struct {
-		parent           *requestTransactionManager
-		transactionId    int32
-		operation        pool.Runnable
-		completionFuture pool.CompletionFuture
+		parent        *requestTransactionManager
+		transactionId int32
+		operation     pool.Runnable
 	}
 	type args struct {
 		ctx context.Context
@@ -285,13 +288,12 @@ func Test_requestTransaction_AwaitCompletion(t1 *testing.T) {
 					return ctx
 				}(),
 			},
-			mockSetup: func(t *testing.T, fields *fields, args *args) {
-				completionFuture := NewMockCompletionFuture(t)
-				expect := completionFuture.EXPECT()
-				expect.AwaitCompletion(mock.Anything).Return(nil)
-				fields.completionFuture = completionFuture
-			},
 			manipulator: func(t *testing.T, transaction *requestTransaction) {
+				completionFutureMock := NewMockCompletionFuture(t)
+				expect := completionFutureMock.EXPECT()
+				expect.AwaitCompletion(mock.Anything).Return(nil)
+				var completionFuture pool.CompletionFuture = completionFutureMock
+				transaction.completionFuture.Store(&completionFuture)
 				go func() {
 					time.Sleep(100 * time.Millisecond)
 					r := transaction.parent
@@ -308,11 +310,13 @@ func Test_requestTransaction_AwaitCompletion(t1 *testing.T) {
 				tt.mockSetup(t1, &tt.fields, &tt.args)
 			}
 			t := &requestTransaction{
-				parent:           tt.fields.parent,
-				transactionId:    tt.fields.transactionId,
-				operation:        tt.fields.operation,
-				completionFuture: tt.fields.completionFuture,
-				transactionLog:   testutils.ProduceTestingLogger(t1),
+				parent:         tt.fields.parent,
+				transactionId:  tt.fields.transactionId,
+				operation:      tt.fields.operation,
+				transactionLog: testutils.ProduceTestingLogger(t1),
+			}
+			if tt.manipulator != nil {
+				tt.manipulator(t1, t)
 			}
 			if err := t.AwaitCompletion(tt.args.ctx); (err != nil) != tt.wantErr {
 				t1.Errorf("AwaitCompletion() error = %v, wantErr %v", err, tt.wantErr)
diff --git a/plc4go/spi/transactions/requestTransaction_plc4xgen.go b/plc4go/spi/transactions/requestTransaction_plc4xgen.go
index d6dd6afa75..733c59fde4 100644
--- a/plc4go/spi/transactions/requestTransaction_plc4xgen.go
+++ b/plc4go/spi/transactions/requestTransaction_plc4xgen.go
@@ -47,8 +47,9 @@ func (d *requestTransaction) SerializeWithWriteBuffer(ctx context.Context, write
 		return err
 	}
 
-	if d.completionFuture != nil {
-		if serializableField, ok := d.completionFuture.(utils.Serializable); ok {
+	if completionFutureLoaded := d.completionFuture.Load(); completionFutureLoaded != nil && *completionFutureLoaded != nil {
+		completionFuture := *completionFutureLoaded
+		if serializableField, ok := completionFuture.(utils.Serializable); ok {
 			if err := writeBuffer.PushContext("completionFuture"); err != nil {
 				return err
 			}
@@ -59,7 +60,7 @@ func (d *requestTransaction) SerializeWithWriteBuffer(ctx context.Context, write
 				return err
 			}
 		} else {
-			stringValue := fmt.Sprintf("%v", d.completionFuture)
+			stringValue := fmt.Sprintf("%v", completionFuture)
 			if err := writeBuffer.WriteString("completionFuture", uint32(len(stringValue)*8), "UTF-8", stringValue); err != nil {
 				return err
 			}