Posted to commits@beam.apache.org by lo...@apache.org on 2023/06/15 19:31:49 UTC
[beam] branch master updated: Move to a conditionVariable for messages+state stream + test. (#27060)
This is an automated email from the ASF dual-hosted git repository.
lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 644f5399dce Move to a conditionVariable for messages+state stream + test. (#27060)
644f5399dce is described below
commit 644f5399dce646ba9b6b0146b44ff7778fc0a8c1
Author: Robert Burke <lo...@users.noreply.github.com>
AuthorDate: Thu Jun 15 12:31:41 2023 -0700
Move to a conditionVariable for messages+state stream + test. (#27060)
Co-authored-by: lostluck <13...@users.noreply.github.com>
---
.../beam/runners/prism/internal/jobservices/job.go | 45 +++-
.../prism/internal/jobservices/management.go | 83 +++++---
.../prism/internal/jobservices/management_test.go | 230 ++++++++++++++++++++-
3 files changed, 315 insertions(+), 43 deletions(-)
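
Before the diff: the old implementation pushed messages and state through buffered channels, which meant a job could only make progress while something was actively reading GetMessageStream (the removed TODO in management.go below calls this out). The new implementation guards a shared buffer with a sync.Cond and Broadcasts on every update, so producers never block and any number of subscribers, including late ones, can catch up at their own pace. A minimal, hedged sketch of that broadcast pattern using illustrative names (broadcastBuffer, send, subscribe), not the Beam types:

package main

import (
	"fmt"
	"sync"
)

// broadcastBuffer is a hypothetical, stripped-down analogue of Job's message
// handling: producers append under the lock and Broadcast, consumers Wait
// until something new appears.
type broadcastBuffer struct {
	cond *sync.Cond
	msgs []string
	done bool
}

func newBroadcastBuffer() *broadcastBuffer {
	return &broadcastBuffer{cond: sync.NewCond(&sync.Mutex{})}
}

func (b *broadcastBuffer) send(msg string) {
	b.cond.L.Lock()
	defer b.cond.L.Unlock()
	b.msgs = append(b.msgs, msg)
	b.cond.Broadcast() // wake every waiting subscriber
}

func (b *broadcastBuffer) close() {
	b.cond.L.Lock()
	defer b.cond.L.Unlock()
	b.done = true
	b.cond.Broadcast()
}

// subscribe prints messages as they arrive; each subscriber keeps its own cursor.
func (b *broadcastBuffer) subscribe(name string, wg *sync.WaitGroup) {
	defer wg.Done()
	b.cond.L.Lock()
	defer b.cond.L.Unlock()
	cur := 0
	for {
		for cur >= len(b.msgs) && !b.done {
			b.cond.Wait() // releases the lock while waiting
		}
		for ; cur < len(b.msgs); cur++ {
			fmt.Println(name, "got:", b.msgs[cur])
		}
		if b.done {
			return
		}
	}
}

func main() {
	b := newBroadcastBuffer()
	var wg sync.WaitGroup
	wg.Add(2)
	go b.subscribe("sub1", &wg)
	go b.subscribe("sub2", &wg)
	for i := 0; i < 5; i++ {
		b.send(fmt.Sprintf("message %d", i))
	}
	b.close()
	wg.Wait()
}

Each subscriber keeps its own cursor into the shared slice, so multiple message streams can observe the same messages instead of competing for them on a single channel.
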
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
index 5b8e786ac6f..4ac37c5db59 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
@@ -29,6 +29,7 @@ import (
"fmt"
"sort"
"strings"
+ "sync"
"sync/atomic"
fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
@@ -70,9 +71,12 @@ type Job struct {
options *structpb.Struct
// Management side concerns.
- msgChan chan string
- state atomic.Value // jobpb.JobState_Enum
- stateChan chan jobpb.JobState_Enum
+ streamCond *sync.Cond
+ // TODO, consider unifying messages and state to a single ordered buffer.
+ minMsg, maxMsg int // logical indices into the message slice
+ msgs []string
+ stateIdx int
+ state atomic.Value // jobpb.JobState_Enum
// Context used to terminate this job.
RootCtx context.Context
@@ -107,25 +111,50 @@ func (j *Job) LogValue() slog.Value {
}
func (j *Job) SendMsg(msg string) {
- j.msgChan <- msg
+ j.streamCond.L.Lock()
+ defer j.streamCond.L.Unlock()
+ j.maxMsg++
+ // Trim so we never retain more than 120 messages, always keeping at least the
+ // last 100, and amortize the cost so that trimming only happens once 20 messages
+ // accumulate beyond that.
+ // TODO, make this configurable
+ const buffered, trigger = 100, 20
+ if len(j.msgs) > buffered+trigger {
+ copy(j.msgs[0:], j.msgs[trigger:])
+ for k, n := len(j.msgs)-trigger, len(j.msgs); k < n; k++ {
+ j.msgs[k] = ""
+ }
+ j.msgs = j.msgs[:len(j.msgs)-trigger]
+ j.minMsg += trigger // advance the logical index of the oldest retained message.
+ }
+ j.msgs = append(j.msgs, msg)
+ j.streamCond.Broadcast()
+}
+
+func (j *Job) sendState(state jobpb.JobState_Enum) {
+ j.streamCond.L.Lock()
+ defer j.streamCond.L.Unlock()
+ j.stateIdx++
+ j.state.Store(state)
+ j.streamCond.Broadcast()
}
// Start indicates that the job is preparing to execute.
func (j *Job) Start() {
- j.stateChan <- jobpb.JobState_STARTING
+ j.sendState(jobpb.JobState_STARTING)
}
// Running indicates that the job is executing.
func (j *Job) Running() {
- j.stateChan <- jobpb.JobState_RUNNING
+ j.sendState(jobpb.JobState_RUNNING)
}
// Done indicates that the job completed successfully.
func (j *Job) Done() {
- j.stateChan <- jobpb.JobState_DONE
+ j.sendState(jobpb.JobState_DONE)
}
// Failed indicates that the job completed unsuccessfully.
func (j *Job) Failed() {
- j.stateChan <- jobpb.JobState_FAILED
+ j.sendState(jobpb.JobState_FAILED)
}
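
A note on the bookkeeping above: maxMsg counts every message ever sent, minMsg is the logical index of the oldest message still held in msgs after trimming, and a reader's logical cursor maps to a slice offset by subtracting minMsg. A small sketch of that index arithmetic (readRetained and its parameters are illustrative, not part of the Beam API):

package main

import "fmt"

// readRetained returns the retained messages at or after the reader's logical
// cursor, plus the advanced cursor. minMsg is the logical index of msgs[0].
func readRetained(msgs []string, minMsg, maxMsg, cur int) ([]string, int) {
	if cur < minMsg {
		// The reader fell behind a trim; jump forward to the oldest retained
		// message (the real handler has a TODO about reporting the gap).
		cur = minMsg
	}
	var out []string
	for ; cur < maxMsg; cur++ {
		out = append(out, msgs[cur-minMsg])
	}
	return out, cur
}

func main() {
	// 40 messages sent, the oldest 20 trimmed away: only m20..m39 are retained.
	retained := make([]string, 0, 20)
	for i := 20; i < 40; i++ {
		retained = append(retained, fmt.Sprintf("m%d", i))
	}
	got, cur := readRetained(retained, 20, 40, 0)
	fmt.Println(len(got), cur) // 20 40 — a late reader sees only the retained tail.
}
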
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
index 6e774332f0e..cecd95536ae 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
@@ -18,6 +18,7 @@ package jobservices
import (
"context"
"fmt"
+ "sync"
"sync/atomic"
jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1"
@@ -69,20 +70,17 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jo
// Since jobs execute in the background, they should not be tied to a request's context.
rootCtx, cancelFn := context.WithCancel(context.Background())
job := &Job{
- key: s.nextId(),
- Pipeline: req.GetPipeline(),
- jobName: req.GetJobName(),
- options: req.GetPipelineOptions(),
-
- msgChan: make(chan string, 100),
- stateChan: make(chan jobpb.JobState_Enum, 1),
- RootCtx: rootCtx,
- CancelFn: cancelFn,
+ key: s.nextId(),
+ Pipeline: req.GetPipeline(),
+ jobName: req.GetJobName(),
+ options: req.GetPipelineOptions(),
+ streamCond: sync.NewCond(&sync.Mutex{}),
+ RootCtx: rootCtx,
+ CancelFn: cancelFn,
}
// Queue initial state of the job.
job.state.Store(jobpb.JobState_STOPPED)
- job.stateChan <- job.state.Load().(jobpb.JobState_Enum)
if err := isSupported(job.Pipeline.GetRequirements()); err != nil {
slog.Error("unable to run job", slog.String("error", err.Error()), slog.String("jobname", req.GetJobName()))
@@ -165,15 +163,45 @@ func (s *Server) Run(ctx context.Context, req *jobpb.RunJobRequest) (*jobpb.RunJ
}, nil
}
-// GetMessageStream subscribes to a stream of state changes and messages from the job
+// GetMessageStream subscribes to a stream of state changes and messages from the job. If message
+// throughput is high, older messages may be dropped from the stream.
func (s *Server) GetMessageStream(req *jobpb.JobMessagesRequest, stream jobpb.JobService_GetMessageStreamServer) error {
s.mu.Lock()
- job := s.jobs[req.GetJobId()]
+ job, ok := s.jobs[req.GetJobId()]
s.mu.Unlock()
+ if !ok {
+ return fmt.Errorf("job with id %v not found", req.GetJobId())
+ }
+ job.streamCond.L.Lock()
+ defer job.streamCond.L.Unlock()
+ curMsg := job.minMsg
+ curState := job.stateIdx
+
+ stream.Context()
+
+ state := job.state.Load().(jobpb.JobState_Enum)
for {
- select {
- case msg := <-job.msgChan:
+ for (curMsg >= job.maxMsg || len(job.msgs) == 0) && curState > job.stateIdx {
+ switch state {
+ case jobpb.JobState_CANCELLED, jobpb.JobState_DONE, jobpb.JobState_DRAINED, jobpb.JobState_FAILED, jobpb.JobState_UPDATED:
+ // Reached terminal state.
+ return nil
+ }
+ job.streamCond.Wait()
+ select { // Quit out if the external connection is done.
+ case <-stream.Context().Done():
+ return stream.Context().Err()
+ default:
+ }
+ }
+
+ if curMsg < job.minMsg {
+ // TODO report missed messages for this stream.
+ curMsg = job.minMsg
+ }
+ for curMsg < job.maxMsg && len(job.msgs) > 0 {
+ msg := job.msgs[curMsg-job.minMsg]
stream.Send(&jobpb.JobMessagesResponse{
Response: &jobpb.JobMessagesResponse_MessageResponse{
MessageResponse: &jobpb.JobMessage{
@@ -182,20 +210,12 @@ func (s *Server) GetMessageStream(req *jobpb.JobMessagesRequest, stream jobpb.Jo
},
},
})
-
- case state, ok := <-job.stateChan:
- // TODO: Don't block job execution if WaitForCompletion isn't being run.
- // The state channel means the job may only execute if something is observing
- // the message stream, as the send on the state or message channel may block
- // once full.
- // Not a problem for tests or short lived batch, but would be hazardous for
- // asynchronous jobs.
-
- // Channel is closed, so the job must be done.
- if !ok {
- state = jobpb.JobState_DONE
- }
- job.state.Store(state)
+ curMsg++
+ }
+ if curState <= job.stateIdx {
+ state = job.state.Load().(jobpb.JobState_Enum)
+ curState = job.stateIdx + 1
+ job.streamCond.L.Unlock()
stream.Send(&jobpb.JobMessagesResponse{
Response: &jobpb.JobMessagesResponse_StateResponse{
StateResponse: &jobpb.JobStateEvent{
@@ -203,14 +223,9 @@ func (s *Server) GetMessageStream(req *jobpb.JobMessagesRequest, stream jobpb.Jo
},
},
})
- switch state {
- case jobpb.JobState_CANCELLED, jobpb.JobState_DONE, jobpb.JobState_DRAINED, jobpb.JobState_FAILED, jobpb.JobState_UPDATED:
- // Reached terminal state.
- return nil
- }
+ job.streamCond.L.Lock()
}
}
-
}
// GetJobMetrics Fetch metrics for a given job.
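
Two details of the handler loop above are worth calling out: cond.Wait is only woken by a Broadcast, so the handler checks stream.Context() after each wake-up to notice a departed client, and it drops the lock around the blocking state Send so producers are not stalled behind network I/O. A compact sketch of that shape, with slowSend standing in for stream.Send and plain strings standing in for jobpb.JobState_Enum (illustrative only, not the Beam code):

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// slowSend stands in for stream.Send: a blocking write to the client.
func slowSend(s string) {
	time.Sleep(5 * time.Millisecond)
	fmt.Println("sent:", s)
}

func main() {
	cond := sync.NewCond(&sync.Mutex{})
	var updates []string
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Producer: append under the lock and Broadcast, like SendMsg/sendState.
	go func() {
		for _, u := range []string{"STARTING", "RUNNING", "DONE"} {
			time.Sleep(10 * time.Millisecond)
			cond.L.Lock()
			updates = append(updates, u)
			cond.Broadcast()
			cond.L.Unlock()
		}
	}()

	// Consumer: the shape of the handler's inner loop.
	cond.L.Lock()
	cur := 0
	for {
		for cur >= len(updates) {
			cond.Wait() // woken by Broadcast; the lock is released while waiting
			select {    // bail out if the client connection is gone
			case <-ctx.Done():
				cond.L.Unlock()
				return
			default:
			}
		}
		u := updates[cur]
		cur++
		cond.L.Unlock() // don't hold the mutex while blocking on I/O
		slowSend(u)
		cond.L.Lock()
		if u == "DONE" { // terminal state: end the stream so the client sees io.EOF
			break
		}
	}
	cond.L.Unlock()
}

In the commit itself the lock is only dropped around the state-notification Send; the sketch applies the same idea to every send for brevity.
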
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go
index b7861276702..5813e6ef73e 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go
@@ -17,6 +17,10 @@ package jobservices
import (
"context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
"sync"
"testing"
@@ -27,6 +31,9 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/insecure"
+ "google.golang.org/grpc/test/bufconn"
"google.golang.org/protobuf/testing/protocmp"
)
@@ -218,4 +225,225 @@ func TestServer(t *testing.T) {
}
}
-// TODO impelment message stream test, once message/State implementation is sync.Cond based.
+func TestGetMessageStream(t *testing.T) {
+ wantName := "testJob"
+ wantPipeline := &pipepb.Pipeline{
+ Requirements: []string{urns.RequirementSplittableDoFn},
+ }
+ var called sync.WaitGroup
+ called.Add(1)
+ ctx, _, clientConn := serveTestServer(t, func(j *Job) {
+ j.Start()
+ j.SendMsg("job starting")
+ j.Running()
+ j.SendMsg("job running")
+ j.SendMsg("job finished")
+ j.Done()
+ j.SendMsg("job done")
+ called.Done()
+ })
+ jobCli := jobpb.NewJobServiceClient(clientConn)
+
+ // PreJob submission
+ msgStream, err := jobCli.GetMessageStream(ctx, &jobpb.JobMessagesRequest{
+ JobId: "job-001",
+ })
+ if err != nil {
+ t.Errorf("GetMessageStream: wanted successful connection, got %v", err)
+ }
+ _, err = msgStream.Recv()
+ if err == nil {
+ t.Error("wanted error on non-existent job, but didn't happen.")
+ }
+
+ prepResp, err := jobCli.Prepare(ctx, &jobpb.PrepareJobRequest{
+ Pipeline: wantPipeline,
+ JobName: wantName,
+ })
+ if err != nil {
+ t.Fatalf("Prepare(%v) = %v, want nil", wantName, err)
+ }
+
+ // Post Job submission
+ msgStream, err = jobCli.GetMessageStream(ctx, &jobpb.JobMessagesRequest{
+ JobId: "job-001",
+ })
+ if err != nil {
+ t.Errorf("GetMessageStream: wanted successful connection, got %v", err)
+ }
+ stateResponse, err := msgStream.Recv()
+ if err != nil {
+ t.Errorf("GetMessageStream().Recv() = %v, want nil", err)
+ }
+ if got, want := stateResponse.GetStateResponse().GetState(), jobpb.JobState_STOPPED; got != want {
+ t.Errorf("GetMessageStream().Recv() = %v, want %v", got, want)
+ }
+
+ _, err = jobCli.Run(ctx, &jobpb.RunJobRequest{
+ PreparationId: prepResp.GetPreparationId(),
+ })
+ if err != nil {
+ t.Fatalf("Run(%v) = %v, want nil", wantName, err)
+ }
+
+ called.Wait() // Wait for the job to terminate.
+
+ receivedDone := false
+ var msgCount int
+ for {
+ // Continue with the same message stream.
+ resp, err := msgStream.Recv()
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ break // successful message stream completion
+ }
+ t.Errorf("GetMessageStream().Recv() = %v, want nil", err)
+ }
+ switch {
+
+ case resp.GetMessageResponse() != nil:
+ msgCount++
+ case resp.GetStateResponse() != nil:
+ if resp.GetStateResponse().GetState() == jobpb.JobState_DONE {
+ receivedDone = true
+ }
+ }
+ }
+ if got, want := msgCount, 4; got != want {
+ t.Errorf("GetMessageStream() didn't return the correct number of messages, got %v, want %v", got, want)
+ }
+ if !receivedDone {
+ t.Error("GetMessageStream() didn't return job done state")
+ }
+ msgStream.CloseSend()
+
+ // Create a new message stream; we should still get a tail of messages (in this case, all of them)
+ // and the final state.
+ msgStream, err = jobCli.GetMessageStream(ctx, &jobpb.JobMessagesRequest{
+ JobId: "job-001",
+ })
+ if err != nil {
+ t.Errorf("GetMessageStream: wanted successful connection, got %v", err)
+ }
+
+ receivedDone = false
+ msgCount = 0
+ for {
+ // Read from the newly created message stream.
+ resp, err := msgStream.Recv()
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ break // successful message stream completion
+ }
+ t.Errorf("GetMessageStream().Recv() = %v, want nil", err)
+ }
+ switch {
+
+ case resp.GetMessageResponse() != nil:
+ msgCount++
+ case resp.GetStateResponse() != nil:
+ if resp.GetStateResponse().GetState() == jobpb.JobState_DONE {
+ receivedDone = true
+ }
+ }
+ }
+ if got, want := msgCount, 4; got != want {
+ t.Errorf("GetMessageStream() didn't return the correct number of messages, got %v, want %v", got, want)
+ }
+ if !receivedDone {
+ t.Error("GetMessageStream() didn't return job done state")
+ }
+}
+
+func TestGetMessageStream_BufferCycling(t *testing.T) {
+ wantName := "testJob"
+ wantPipeline := &pipepb.Pipeline{
+ Requirements: []string{urns.RequirementSplittableDoFn},
+ }
+ var called sync.WaitGroup
+ called.Add(1)
+ ctx, _, clientConn := serveTestServer(t, func(j *Job) {
+ j.Start()
+ // Send a count that doesn't line up with the trim trigger to exercise the expected
+ // behavior (a reader can sometimes see more than the last 100 messages).
+ for i := 0; i < 512; i++ {
+ j.SendMsg(fmt.Sprintf("message number %v", i))
+ }
+ j.Done()
+ called.Done()
+ })
+ jobCli := jobpb.NewJobServiceClient(clientConn)
+
+ prepResp, err := jobCli.Prepare(ctx, &jobpb.PrepareJobRequest{
+ Pipeline: wantPipeline,
+ JobName: wantName,
+ })
+ if err != nil {
+ t.Fatalf("Prepare(%v) = %v, want nil", wantName, err)
+ }
+ _, err = jobCli.Run(ctx, &jobpb.RunJobRequest{
+ PreparationId: prepResp.GetPreparationId(),
+ })
+ if err != nil {
+ t.Fatalf("Run(%v) = %v, want nil", wantName, err)
+ }
+
+ called.Wait() // Wait for the job to terminate.
+
+ // Create a new message stream; we should still get a tail of messages (in this case, the retained ones)
+ // and the final state.
+ msgStream, err := jobCli.GetMessageStream(ctx, &jobpb.JobMessagesRequest{
+ JobId: "job-001",
+ })
+ if err != nil {
+ t.Errorf("GetMessageStream: wanted successful connection, got %v", err)
+ }
+
+ receivedDone := false
+ var msgCount int
+ for {
+ // Read from the newly created message stream.
+ resp, err := msgStream.Recv()
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ break // successful message stream completion
+ }
+ t.Errorf("GetMessageStream().Recv() = %v, want nil", err)
+ }
+ switch {
+ case resp.GetMessageResponse() != nil:
+ msgCount++
+ case resp.GetStateResponse() != nil:
+ if resp.GetStateResponse().GetState() == jobpb.JobState_DONE {
+ receivedDone = true
+ }
+ }
+ }
+ if got, want := msgCount, 112; got != want {
+ t.Errorf("GetMessageStream() didn't return the correct number of messages, got %v, want %v", got, want)
+ }
+ if !receivedDone {
+ t.Error("GetMessageStream() didn't return job done state")
+ }
+
+}
+
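
The expected count of 112 above follows from the constants in SendMsg: with buffered=100 and trigger=20, the slice grows to 121 entries before the first trim, and each later group of 20 sends triggers another trim of 20, so 512 sends leave 512 - 20*20 = 112 retained messages. A tiny simulation of the same policy (not the Beam code) confirms the number:

package main

import "fmt"

func main() {
	// Replays SendMsg's trimming policy: keep at least the last `buffered`
	// messages, trimming `trigger` at a time once the slice exceeds both.
	const buffered, trigger = 100, 20
	var msgs []string
	for i := 0; i < 512; i++ {
		if len(msgs) > buffered+trigger {
			msgs = msgs[trigger:]
		}
		msgs = append(msgs, fmt.Sprintf("message number %v", i))
	}
	fmt.Println(len(msgs)) // 112 — the message count the test expects
}
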
+func serveTestServer(t *testing.T, execute func(j *Job)) (context.Context, *Server, *grpc.ClientConn) {
+ t.Helper()
+ ctx, cancelFn := context.WithCancel(context.Background())
+ t.Cleanup(cancelFn)
+
+ s := NewServer(0, execute)
+ lis := bufconn.Listen(1024 * 64)
+ s.lis = lis
+ t.Cleanup(func() { s.Stop() })
+ go s.Serve()
+
+ clientConn, err := grpc.DialContext(ctx, "", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
+ return lis.DialContext(ctx)
+ }), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithBlock())
+ if err != nil {
+ t.Fatal("couldn't create bufconn grpc connection:", err)
+ }
+ return ctx, s, clientConn
+}