You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@plc4x.apache.org by sr...@apache.org on 2023/06/19 13:51:00 UTC
[plc4x] 05/05: fix(plc4go/spi): fix race issues in request transaction
This is an automated email from the ASF dual-hosted git repository.
sruehl pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/plc4x.git
commit 0458529ee1acfda55ae7ffaf75d6c95045a06fea
Author: Sebastian Rühl <sr...@apache.org>
AuthorDate: Mon Jun 19 15:49:48 2023 +0200
fix(plc4go/spi): fix race issues in request transaction
---
plc4go/spi/transactions/RequestTransaction.go | 25 +++-
.../spi/transactions/RequestTransactionManager.go | 6 +-
.../transactions/RequestTransactionManager_test.go | 26 ++--
plc4go/spi/transactions/RequestTransaction_test.go | 154 +++++++++++----------
.../transactions/requestTransaction_plc4xgen.go | 7 +-
5 files changed, 121 insertions(+), 97 deletions(-)
diff --git a/plc4go/spi/transactions/RequestTransaction.go b/plc4go/spi/transactions/RequestTransaction.go
index 18f88c7293..760a0d6a4c 100644
--- a/plc4go/spi/transactions/RequestTransaction.go
+++ b/plc4go/spi/transactions/RequestTransaction.go
@@ -22,12 +22,15 @@ package transactions
import (
"context"
"fmt"
+ "sync"
+ "sync/atomic"
+ "time"
+
"github.com/apache/plc4x/plc4go/spi/pool"
+
"github.com/pkg/errors"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
- "sync"
- "time"
)
// RequestTransaction represents a transaction
@@ -58,7 +61,7 @@ type requestTransaction struct {
/** The initial operation to perform to kick off the request */
operation pool.Runnable `ignore:"true"` // TODO: maybe we can treat this as a function some day if we are able to check the definition in gen
- completionFuture pool.CompletionFuture
+ completionFuture atomic.Pointer[pool.CompletionFuture]
stateChangeMutex sync.Mutex
completed bool
@@ -66,6 +69,18 @@ type requestTransaction struct {
transactionLog zerolog.Logger `ignore:"true"`
}
+func (t *requestTransaction) setCompletionFuture(completionFuture pool.CompletionFuture) {
+ t.completionFuture.Store(&completionFuture)
+}
+
+func (t *requestTransaction) getCompletionFuture() pool.CompletionFuture {
+ completionFutureLoaded := t.completionFuture.Load()
+ if completionFutureLoaded == nil {
+ return nil
+ }
+ return *completionFutureLoaded
+}
+
//
// Internal section
//
@@ -118,14 +133,14 @@ func (t *requestTransaction) AwaitCompletion(ctx context.Context) error {
t.transactionLog.Trace().Int32("transactionId", t.transactionId).Msg("Awaiting completion")
timeout, cancelFunc := context.WithTimeout(ctx, time.Minute*30) // This is intentionally set very high
defer cancelFunc()
- for t.completionFuture == nil {
+ for t.getCompletionFuture() == nil {
time.Sleep(time.Millisecond * 10)
if err := timeout.Err(); err != nil {
log.Error().Msg("Timout after a long time. This means something is very of here")
return errors.Wrap(err, "Error waiting for completion future to be set")
}
}
- if err := t.completionFuture.AwaitCompletion(ctx); err != nil {
+ if err := t.getCompletionFuture().AwaitCompletion(ctx); err != nil {
t.transactionLog.Trace().Int32("transactionId", t.transactionId).Msg("Errored")
return err
}
diff --git a/plc4go/spi/transactions/RequestTransactionManager.go b/plc4go/spi/transactions/RequestTransactionManager.go
index 83551e0805..9fc80d6f4d 100644
--- a/plc4go/spi/transactions/RequestTransactionManager.go
+++ b/plc4go/spi/transactions/RequestTransactionManager.go
@@ -169,7 +169,7 @@ func (r *requestTransactionManager) processWorklog() {
r.log.Debug().Msgf("Handling next\n%v\n. (Adding to running requests (length: %d))", next, len(r.runningRequests))
r.runningRequests = append(r.runningRequests, next)
completionFuture := r.executor.Submit(context.Background(), next.transactionId, next.operation)
- next.completionFuture = completionFuture
+ next.setCompletionFuture(completionFuture)
r.workLog.Remove(front)
}
}
@@ -190,7 +190,7 @@ func (r *requestTransactionManager) StartTransaction() RequestTransaction {
}
if r.shutdown.Load() {
transaction.completed = true
- transaction.completionFuture = &completedFuture{errors.New("request transaction manager in shutdown")}
+ transaction.setCompletionFuture(&completedFuture{errors.New("request transaction manager in shutdown")})
}
return transaction
}
@@ -203,7 +203,7 @@ func (r *requestTransactionManager) getNumberOfActiveRequests() int {
func (r *requestTransactionManager) failRequest(transaction *requestTransaction, err error) error {
// Try to fail it!
- transaction.completionFuture.Cancel(true, err)
+ transaction.getCompletionFuture().Cancel(true, err)
// End it
return r.endRequest(transaction)
}
diff --git a/plc4go/spi/transactions/RequestTransactionManager_test.go b/plc4go/spi/transactions/RequestTransactionManager_test.go
index 63041083ea..b80299c202 100644
--- a/plc4go/spi/transactions/RequestTransactionManager_test.go
+++ b/plc4go/spi/transactions/RequestTransactionManager_test.go
@@ -280,10 +280,11 @@ func Test_requestTransactionManager_failRequest(t *testing.T) {
transaction: &requestTransaction{},
},
mockSetup: func(t *testing.T, fields *fields, args *args) {
- completionFuture := NewMockCompletionFuture(t)
- expect := completionFuture.EXPECT()
+ completionFutureMock := NewMockCompletionFuture(t)
+ expect := completionFutureMock.EXPECT()
expect.Cancel(true, nil).Return()
- args.transaction.completionFuture = completionFuture
+ var completionFuture pool.CompletionFuture = completionFutureMock
+ args.transaction.completionFuture.Store(&completionFuture)
},
wantErr: true,
},
@@ -374,14 +375,17 @@ func Test_requestTransactionManager_processWorklog(t *testing.T) {
numberOfConcurrentRequests: 100,
workLog: func() list.List {
l := list.New()
- l.PushBack(&requestTransaction{
- transactionId: 1,
- completionFuture: NewMockCompletionFuture(t),
- })
- l.PushBack(&requestTransaction{
- transactionId: 2,
- completionFuture: NewMockCompletionFuture(t),
- })
+ var completionFuture pool.CompletionFuture = NewMockCompletionFuture(t)
+ r1 := &requestTransaction{
+ transactionId: 1,
+ }
+ r1.completionFuture.Store(&completionFuture)
+ l.PushBack(r1)
+ r2 := &requestTransaction{
+ transactionId: 2,
+ }
+ r2.completionFuture.Store(&completionFuture)
+ l.PushBack(r2)
return *l
}(),
executor: sharedExecutorInstance,
diff --git a/plc4go/spi/transactions/RequestTransaction_test.go b/plc4go/spi/transactions/RequestTransaction_test.go
index 4288d703ba..00c68fe1f9 100644
--- a/plc4go/spi/transactions/RequestTransaction_test.go
+++ b/plc4go/spi/transactions/RequestTransaction_test.go
@@ -35,11 +35,10 @@ import (
func Test_requestTransaction_EndRequest(t1 *testing.T) {
type fields struct {
- parent *requestTransactionManager
- transactionId int32
- operation pool.Runnable
- completionFuture pool.CompletionFuture
- completed bool
+ parent *requestTransactionManager
+ transactionId int32
+ operation pool.Runnable
+ completed bool
}
tests := []struct {
name string
@@ -65,12 +64,11 @@ func Test_requestTransaction_EndRequest(t1 *testing.T) {
for _, tt := range tests {
t1.Run(tt.name, func(t1 *testing.T) {
t := &requestTransaction{
- parent: tt.fields.parent,
- transactionId: tt.fields.transactionId,
- operation: tt.fields.operation,
- completionFuture: tt.fields.completionFuture,
- transactionLog: testutils.ProduceTestingLogger(t1),
- completed: tt.fields.completed,
+ parent: tt.fields.parent,
+ transactionId: tt.fields.transactionId,
+ operation: tt.fields.operation,
+ transactionLog: testutils.ProduceTestingLogger(t1),
+ completed: tt.fields.completed,
}
if err := t.EndRequest(); (err != nil) != tt.wantErr {
t1.Errorf("EndRequest() error = %v, wantErr %v", err, tt.wantErr)
@@ -81,33 +79,34 @@ func Test_requestTransaction_EndRequest(t1 *testing.T) {
func Test_requestTransaction_FailRequest(t1 *testing.T) {
type fields struct {
- parent *requestTransactionManager
- transactionId int32
- operation pool.Runnable
- completionFuture pool.CompletionFuture
- transactionLog zerolog.Logger
- completed bool
+ parent *requestTransactionManager
+ transactionId int32
+ operation pool.Runnable
+ transactionLog zerolog.Logger
+ completed bool
}
type args struct {
err error
}
tests := []struct {
- name string
- fields fields
- args args
- mockSetup func(t *testing.T, fields *fields, args *args)
- wantErr assert.ErrorAssertionFunc
+ name string
+ fields fields
+ args args
+ mockSetup func(t *testing.T, fields *fields, args *args)
+ manipulator func(t *testing.T, transaction *requestTransaction)
+ wantErr assert.ErrorAssertionFunc
}{
{
name: "just fail it",
fields: fields{
parent: &requestTransactionManager{},
},
- mockSetup: func(t *testing.T, fields *fields, args *args) {
- completionFuture := NewMockCompletionFuture(t)
- expect := completionFuture.EXPECT()
+ manipulator: func(t *testing.T, transaction *requestTransaction) {
+ completionFutureMock := NewMockCompletionFuture(t)
+ expect := completionFutureMock.EXPECT()
expect.Cancel(true, nil).Return()
- fields.completionFuture = completionFuture
+ var completionFuture pool.CompletionFuture = completionFutureMock
+ transaction.completionFuture.Store(&completionFuture)
},
wantErr: assert.Error,
},
@@ -129,32 +128,37 @@ func Test_requestTransaction_FailRequest(t1 *testing.T) {
tt.mockSetup(t, &tt.fields, &tt.args)
}
r := &requestTransaction{
- parent: tt.fields.parent,
- transactionId: tt.fields.transactionId,
- operation: tt.fields.operation,
- completionFuture: tt.fields.completionFuture,
- transactionLog: tt.fields.transactionLog,
- completed: tt.fields.completed,
+ parent: tt.fields.parent,
+ transactionId: tt.fields.transactionId,
+ operation: tt.fields.operation,
+ transactionLog: tt.fields.transactionLog,
+ completed: tt.fields.completed,
+ }
+ if tt.manipulator != nil {
+ tt.manipulator(t, r)
}
tt.wantErr(t, r.FailRequest(tt.args.err), "FailRequest() error = %v", tt.args.err)
})
}
}
-func Test_requestTransaction_String(t1 *testing.T) {
+func Test_requestTransaction_String(t *testing.T) {
type fields struct {
- parent *requestTransactionManager
- transactionId int32
- operation pool.Runnable
- completionFuture pool.CompletionFuture
+ parent *requestTransactionManager
+ transactionId int32
+ operation pool.Runnable
}
tests := []struct {
- name string
- fields fields
- want string
+ name string
+ fields fields
+ manipulator func(t *testing.T, transaction *requestTransaction)
+ want string
}{
{
name: "give a string",
+ manipulator: func(t *testing.T, transaction *requestTransaction) {
+ transaction.setCompletionFuture(nil)
+ },
want: `
╔═requestTransaction═════════╗
║╔═transactionId╗╔═completed╗║
@@ -164,15 +168,17 @@ func Test_requestTransaction_String(t1 *testing.T) {
},
}
for _, tt := range tests {
- t1.Run(tt.name, func(t1 *testing.T) {
- t := &requestTransaction{
- parent: tt.fields.parent,
- transactionId: tt.fields.transactionId,
- operation: tt.fields.operation,
- completionFuture: tt.fields.completionFuture,
- transactionLog: testutils.ProduceTestingLogger(t1),
+ t.Run(tt.name, func(t1 *testing.T) {
+ _t := &requestTransaction{
+ parent: tt.fields.parent,
+ transactionId: tt.fields.transactionId,
+ operation: tt.fields.operation,
+ transactionLog: testutils.ProduceTestingLogger(t1),
}
- if got := t.String(); got != tt.want {
+ if tt.manipulator != nil {
+ tt.manipulator(t, _t)
+ }
+ if got := _t.String(); got != tt.want {
t1.Errorf("String() = \n%v, want \n%v", got, tt.want)
}
})
@@ -181,12 +187,11 @@ func Test_requestTransaction_String(t1 *testing.T) {
func Test_requestTransaction_Submit(t1 *testing.T) {
type fields struct {
- parent *requestTransactionManager
- transactionId int32
- operation pool.Runnable
- completionFuture pool.CompletionFuture
- transactionLog zerolog.Logger
- completed bool
+ parent *requestTransactionManager
+ transactionId int32
+ operation pool.Runnable
+ transactionLog zerolog.Logger
+ completed bool
}
type args struct {
operation RequestTransactionRunnable
@@ -240,12 +245,11 @@ func Test_requestTransaction_Submit(t1 *testing.T) {
for _, tt := range tests {
t1.Run(tt.name, func(t1 *testing.T) {
t := &requestTransaction{
- parent: tt.fields.parent,
- transactionId: tt.fields.transactionId,
- operation: tt.fields.operation,
- completionFuture: tt.fields.completionFuture,
- transactionLog: tt.fields.transactionLog,
- completed: tt.fields.completed,
+ parent: tt.fields.parent,
+ transactionId: tt.fields.transactionId,
+ operation: tt.fields.operation,
+ transactionLog: tt.fields.transactionLog,
+ completed: tt.fields.completed,
}
t.Submit(tt.args.operation)
t.operation()
@@ -255,10 +259,9 @@ func Test_requestTransaction_Submit(t1 *testing.T) {
func Test_requestTransaction_AwaitCompletion(t1 *testing.T) {
type fields struct {
- parent *requestTransactionManager
- transactionId int32
- operation pool.Runnable
- completionFuture pool.CompletionFuture
+ parent *requestTransactionManager
+ transactionId int32
+ operation pool.Runnable
}
type args struct {
ctx context.Context
@@ -285,13 +288,12 @@ func Test_requestTransaction_AwaitCompletion(t1 *testing.T) {
return ctx
}(),
},
- mockSetup: func(t *testing.T, fields *fields, args *args) {
- completionFuture := NewMockCompletionFuture(t)
- expect := completionFuture.EXPECT()
- expect.AwaitCompletion(mock.Anything).Return(nil)
- fields.completionFuture = completionFuture
- },
manipulator: func(t *testing.T, transaction *requestTransaction) {
+ completionFutureMock := NewMockCompletionFuture(t)
+ expect := completionFutureMock.EXPECT()
+ expect.AwaitCompletion(mock.Anything).Return(nil)
+ var completionFuture pool.CompletionFuture = completionFutureMock
+ transaction.completionFuture.Store(&completionFuture)
go func() {
time.Sleep(100 * time.Millisecond)
r := transaction.parent
@@ -308,11 +310,13 @@ func Test_requestTransaction_AwaitCompletion(t1 *testing.T) {
tt.mockSetup(t1, &tt.fields, &tt.args)
}
t := &requestTransaction{
- parent: tt.fields.parent,
- transactionId: tt.fields.transactionId,
- operation: tt.fields.operation,
- completionFuture: tt.fields.completionFuture,
- transactionLog: testutils.ProduceTestingLogger(t1),
+ parent: tt.fields.parent,
+ transactionId: tt.fields.transactionId,
+ operation: tt.fields.operation,
+ transactionLog: testutils.ProduceTestingLogger(t1),
+ }
+ if tt.manipulator != nil {
+ tt.manipulator(t1, t)
}
if err := t.AwaitCompletion(tt.args.ctx); (err != nil) != tt.wantErr {
t1.Errorf("AwaitCompletion() error = %v, wantErr %v", err, tt.wantErr)
diff --git a/plc4go/spi/transactions/requestTransaction_plc4xgen.go b/plc4go/spi/transactions/requestTransaction_plc4xgen.go
index d6dd6afa75..733c59fde4 100644
--- a/plc4go/spi/transactions/requestTransaction_plc4xgen.go
+++ b/plc4go/spi/transactions/requestTransaction_plc4xgen.go
@@ -47,8 +47,9 @@ func (d *requestTransaction) SerializeWithWriteBuffer(ctx context.Context, write
return err
}
- if d.completionFuture != nil {
- if serializableField, ok := d.completionFuture.(utils.Serializable); ok {
+ if completionFutureLoaded := d.completionFuture.Load(); completionFutureLoaded != nil && *completionFutureLoaded != nil {
+ completionFuture := *completionFutureLoaded
+ if serializableField, ok := completionFuture.(utils.Serializable); ok {
if err := writeBuffer.PushContext("completionFuture"); err != nil {
return err
}
@@ -59,7 +60,7 @@ func (d *requestTransaction) SerializeWithWriteBuffer(ctx context.Context, write
return err
}
} else {
- stringValue := fmt.Sprintf("%v", d.completionFuture)
+ stringValue := fmt.Sprintf("%v", completionFuture)
if err := writeBuffer.WriteString("completionFuture", uint32(len(stringValue)*8), "UTF-8", stringValue); err != nil {
return err
}