Posted to commits@beam.apache.org by lo...@apache.org on 2023/02/20 23:19:34 UTC
[beam] 01/02: [prism] add in execution layer
This is an automated email from the ASF dual-hosted git repository.
lostluck pushed a commit to branch prism-execute
in repository https://gitbox.apache.org/repos/asf/beam.git
commit 7c60f9bd6389e491aec455425ad4e2d0a5426cb9
Author: Robert Burke <ro...@frantil.com>
AuthorDate: Sun Feb 19 14:52:48 2023 -0800
[prism] add in execution layer
---
sdks/go/pkg/beam/runners/prism/internal/execute.go | 644 ++++++++++++++++++++-
.../beam/runners/prism/internal/execute_test.go | 417 +++++++++++++
.../beam/runners/prism/internal/separate_test.go | 593 +++++++++++++++++++
sdks/go/pkg/beam/runners/prism/prism.go | 48 ++
4 files changed, 1701 insertions(+), 1 deletion(-)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index 7c979ebf730..9c74102de8b 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -16,15 +16,459 @@
package internal
import (
+ "bytes"
+ "context"
+ "fmt"
+ "io"
+ "sort"
+
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+ fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
+ "golang.org/x/exp/maps"
+ "golang.org/x/exp/slog"
+ "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/insecure"
+ "google.golang.org/protobuf/proto"
)
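+// portFor marshals a RemoteGrpcPort payload that points the given windowed
+// value coder at this worker's data endpoint. It's used as the payload for
+// the runner's source and sink transforms.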
+func portFor(wInCid string, wk *worker.W) []byte {
+ sourcePort := &fnpb.RemoteGrpcPort{
+ CoderId: wInCid,
+ ApiServiceDescriptor: &pipepb.ApiServiceDescriptor{
+ Url: wk.Endpoint(),
+ },
+ }
+ sourcePortBytes, err := proto.Marshal(sourcePort)
+ if err != nil {
+ slog.Error("bad port", err, slog.String("endpoint", sourcePort.ApiServiceDescriptor.GetUrl()))
+ }
+ return sourcePortBytes
+}
+
+// collateByWindows takes the data and collates it into window-keyed maps.
+// Uses generics to consolidate the repetitive window loops.
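+//
+// A hedged sketch of the generic in use: counting elements per window,
+// assuming ed is an element decoder such as one produced by
+// collectionPullDecoder:
+//
+//	counts := collateByWindows(data, watermark, wDec, wEnc,
+//		func(r io.Reader) int { ed(r); return 1 }, // consume one element, count it
+//		func(a, b int) int { return a + b })       // join partial counts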
+func collateByWindows[T any](data [][]byte, watermark mtime.Time, wDec exec.WindowDecoder, wEnc exec.WindowEncoder, ed func(io.Reader) T, join func(T, T) T) map[typex.Window]T {
+ windowed := map[typex.Window]T{}
+ for _, datum := range data {
+ inBuf := bytes.NewBuffer(datum)
+ for {
+ ws, _, _, err := exec.DecodeWindowedValueHeader(wDec, inBuf)
+ if err == io.EOF {
+ break
+ }
+ // Get the element out, and window them properly.
+ e := ed(inBuf)
+ for _, w := range ws {
+ // if w.MaxTimestamp() > watermark {
+ // var t T
+ // slog.Debug(fmt.Sprintf("collateByWindows[%T]: window not yet closed, skipping %v > %v", t, w.MaxTimestamp(), watermark))
+ // continue
+ // }
+ windowed[w] = join(windowed[w], e)
+ }
+ }
+ }
+ return windowed
+}
+
// stage represents a fused subgraph.
-// temporary implementation to break up PRs.
+//
+// TODO: do we guarantee that they are all
+// the same environment at this point, or
+// should that be handled later?
type stage struct {
+ ID string
transforms []string
+
+ envID string
+ exe transformExecuter
+ outputCount int
+ inputTransformID string
+ mainInputPCol string
+ inputInfo engine.PColInfo
+ desc *fnpb.ProcessBundleDescriptor
+ sides []string
+ prepareSides func(b *worker.B, tid string, watermark mtime.Time)
+
+ SinkToPCollection map[string]string
+ OutputsToCoders map[string]engine.PColInfo
+}
+
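+// Execute runs a single bundle through this stage. Runner transforms are
+// evaluated in-process; SDK transforms are sent to the worker, blocking
+// until complete. Output data is then committed, metrics are tallied, and
+// residuals are handed back to the element manager.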
+func (s *stage) Execute(j *jobservices.Job, wk *worker.W, comps *pipepb.Components, em *engine.ElementManager, rb engine.RunBundle) {
+ tid := s.transforms[0]
+ slog.Debug("Execute: starting bundle", "bundle", rb, slog.String("tid", tid))
+
+ var b *worker.B
+ var send bool
+ inputData := em.InputForBundle(rb, s.inputInfo)
+ switch s.envID {
+ case "": // Runner Transforms
+ // Runner transforms are processed immediately.
+ b = s.exe.ExecuteTransform(tid, comps.GetTransforms()[tid], comps, rb.Watermark, inputData)
+ b.InstID = rb.BundleID
+ slog.Debug("Execute: runner transform", "bundle", rb, slog.String("tid", tid))
+ case wk.ID:
+ send = true
+ b = &worker.B{
+ PBDID: s.ID,
+ InstID: rb.BundleID,
+
+ InputTransformID: s.inputTransformID,
+
+ // TODO Here's where we can split data for processing in multiple bundles.
+ InputData: inputData,
+
+ SinkToPCollection: s.SinkToPCollection,
+ OutputCount: s.outputCount,
+ }
+ b.Init()
+
+ s.prepareSides(b, s.transforms[0], rb.Watermark)
+ default:
+ err := fmt.Errorf("unknown environment[%v]", s.envID)
+ slog.Error("Execute", err)
+ panic(err)
+ }
+
+ if send {
+ slog.Debug("Execute: processing", "bundle", rb)
+ b.ProcessOn(wk) // Blocks until finished.
+ }
+ // Tentative Data is ready, commit it to the main datastore.
+ slog.Debug("Execute: commiting data", "bundle", rb, slog.Any("outputsWithData", maps.Keys(b.OutputData.Raw)), slog.Any("outputs", maps.Keys(s.OutputsToCoders)))
+
+ resp := &fnpb.ProcessBundleResponse{}
+ if send {
+ resp = <-b.Resp
+ // Tally metrics immediately so they're available before
+ // pipeline termination.
+ j.ContributeMetrics(resp)
+ }
+ // TODO handle side input data properly.
+ wk.D.Commit(b.OutputData)
+ var residualData [][]byte
+ var minOutputWatermark map[string]mtime.Time
+ for _, rr := range resp.GetResidualRoots() {
+ ba := rr.GetApplication()
+ residualData = append(residualData, ba.GetElement())
+ if len(ba.GetElement()) == 0 {
+ slog.Log(slog.LevelError, "returned empty residual application", "bundle", rb)
+ panic("sdk returned empty residual application")
+ }
+ for col, wm := range ba.GetOutputWatermarks() {
+ if minOutputWatermark == nil {
+ minOutputWatermark = map[string]mtime.Time{}
+ }
+ cur, ok := minOutputWatermark[col]
+ if !ok {
+ cur = mtime.MaxTimestamp
+ }
+ minOutputWatermark[col] = mtime.Min(mtime.FromTime(wm.AsTime()), cur)
+ }
+ }
+ if l := len(residualData); l > 0 {
+ slog.Debug("returned empty residual application", "bundle", rb, slog.Int("numResiduals", l), slog.String("pcollection", s.mainInputPCol))
+ }
+ em.PersistBundle(rb, s.OutputsToCoders, b.OutputData, s.inputInfo, residualData, minOutputWatermark)
+ b.OutputData = engine.TentativeData{} // Clear the data.
+}
+
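+// buildStage assembles the ProcessBundleDescriptor for a transform destined
+// for the SDK worker: it attaches a gRPC source for the main input and a
+// gRPC sink per output, gathers the needed coders (including side input
+// coders), and registers the descriptor with the worker.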
+func buildStage(s *stage, tid string, t *pipepb.PTransform, comps *pipepb.Components, wk *worker.W) {
+ s.inputTransformID = tid + "_source"
+
+ coders := map[string]*pipepb.Coder{}
+ transforms := map[string]*pipepb.PTransform{
+ tid: t, // The Transform to Execute!
+ }
+
+ sis, err := getSideInputs(t)
+ if err != nil {
+ slog.Error("buildStage: getSide Inputs", err, slog.String("transformID", tid))
+ panic(err)
+ }
+ var inputInfo engine.PColInfo
+ var sides []string
+ for local, global := range t.GetInputs() {
+ // This id is directly used for the source, but this also copies
+ // coders used by side inputs to the coders map for the bundle, so
+ // needs to be run for every ID.
+ wInCid := makeWindowedValueCoder(global, comps, coders)
+ _, ok := sis[local]
+ if ok {
+ sides = append(sides, global)
+ } else {
+ // this is the main input
+ transforms[s.inputTransformID] = sourceTransform(s.inputTransformID, portFor(wInCid, wk), global)
+ col := comps.GetPcollections()[global]
+ ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+ wDec, wEnc := getWindowValueCoders(comps, col, coders)
+ inputInfo = engine.PColInfo{
+ GlobalID: global,
+ WDec: wDec,
+ WEnc: wEnc,
+ EDec: ed,
+ }
+ }
+ // We need to process all inputs to ensure we have all input coders, so we must continue.
+ }
+
+ prepareSides, err := handleSideInputs(t, comps, coders, wk)
+ if err != nil {
+ slog.Error("buildStage: handleSideInputs", err, slog.String("transformID", tid))
+ panic(err)
+ }
+
+ // TODO: We need a new logical PCollection to represent the source
+ // so we can avoid double counting PCollection metrics later.
+ // But this also means replacing the ID for the input in the bundle.
+ sink2Col := map[string]string{}
+ col2Coders := map[string]engine.PColInfo{}
+ for local, global := range t.GetOutputs() {
+ wOutCid := makeWindowedValueCoder(global, comps, coders)
+ sinkID := tid + "_" + local
+ col := comps.GetPcollections()[global]
+ ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+ wDec, wEnc := getWindowValueCoders(comps, col, coders)
+ sink2Col[sinkID] = global
+ col2Coders[global] = engine.PColInfo{
+ GlobalID: global,
+ WDec: wDec,
+ WEnc: wEnc,
+ EDec: ed,
+ }
+ transforms[sinkID] = sinkTransform(sinkID, portFor(wOutCid, wk), global)
+ }
+
+ reconcileCoders(coders, comps.GetCoders())
+
+ desc := &fnpb.ProcessBundleDescriptor{
+ Id: s.ID,
+ Transforms: transforms,
+ WindowingStrategies: comps.GetWindowingStrategies(),
+ Pcollections: comps.GetPcollections(),
+ Coders: coders,
+ StateApiServiceDescriptor: &pipepb.ApiServiceDescriptor{
+ Url: wk.Endpoint(),
+ },
+ }
+
+ s.desc = desc
+ s.outputCount = len(t.Outputs)
+ s.prepareSides = prepareSides
+ s.sides = sides
+ s.SinkToPCollection = sink2Col
+ s.OutputsToCoders = col2Coders
+ s.mainInputPCol = inputInfo.GlobalID
+ s.inputInfo = inputInfo
+
+ wk.Descriptors[s.ID] = s.desc
+}
+
+func getSideInputs(t *pipepb.PTransform) (map[string]*pipepb.SideInput, error) {
+ if t.GetSpec().GetUrn() != urns.TransformParDo {
+ return nil, nil
+ }
+ pardo := &pipepb.ParDoPayload{}
+ if err := (proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != nil {
+ return nil, fmt.Errorf("unable to decode ParDoPayload")
+ }
+ return pardo.GetSideInputs(), nil
+}
+
+// handleSideInputs ensures appropriate coders are available to the bundle, and prepares a function to stage the data.
+func handleSideInputs(t *pipepb.PTransform, comps *pipepb.Components, coders map[string]*pipepb.Coder, wk *worker.W) (func(b *worker.B, tid string, watermark mtime.Time), error) {
+ sis, err := getSideInputs(t)
+ if err != nil {
+ return nil, err
+ }
+ var prepSides []func(b *worker.B, tid string, watermark mtime.Time)
+
+ // Get WindowedValue Coders for the transform's input and output PCollections.
+ for local, global := range t.GetInputs() {
+ si, ok := sis[local]
+ if !ok {
+ continue // This is the main input.
+ }
+
+ // this is a side input
+ switch si.GetAccessPattern().GetUrn() {
+ case urns.SideInputIterable:
+ slog.Debug("urnSideInputIterable",
+ slog.String("sourceTransform", t.GetUniqueName()),
+ slog.String("local", local),
+ slog.String("global", global))
+ col := comps.GetPcollections()[global]
+ ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+ wDec, wEnc := getWindowValueCoders(comps, col, coders)
+ // May be of zero length, but that's OK. Side inputs can be empty.
+
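+ // Capture this iteration's values for the closure below, since Go
+ // reuses loop variables across iterations.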
+ global, local := global, local
+ prepSides = append(prepSides, func(b *worker.B, tid string, watermark mtime.Time) {
+ data := wk.D.GetAllData(global)
+
+ if b.IterableSideInputData == nil {
+ b.IterableSideInputData = map[string]map[string]map[typex.Window][][]byte{}
+ }
+ if _, ok := b.IterableSideInputData[tid]; !ok {
+ b.IterableSideInputData[tid] = map[string]map[typex.Window][][]byte{}
+ }
+ b.IterableSideInputData[tid][local] = collateByWindows(data, watermark, wDec, wEnc,
+ func(r io.Reader) [][]byte {
+ return [][]byte{ed(r)}
+ }, func(a, b [][]byte) [][]byte {
+ return append(a, b...)
+ })
+ })
+
+ case urns.SideInputMultiMap:
+ slog.Debug("urnSideInputMultiMap",
+ slog.String("sourceTransform", t.GetUniqueName()),
+ slog.String("local", local),
+ slog.String("global", global))
+ col := comps.GetPcollections()[global]
+
+ kvc := comps.GetCoders()[col.GetCoderId()]
+ if kvc.GetSpec().GetUrn() != urns.CoderKV {
+ return nil, fmt.Errorf("multimap side inputs needs KV coder, got %v", kvc.GetSpec().GetUrn())
+ }
+
+ kd := collectionPullDecoder(kvc.GetComponentCoderIds()[0], coders, comps)
+ vd := collectionPullDecoder(kvc.GetComponentCoderIds()[1], coders, comps)
+ wDec, wEnc := getWindowValueCoders(comps, col, coders)
+
+ global, local := global, local
+ prepSides = append(prepSides, func(b *worker.B, tid string, watermark mtime.Time) {
+ // May be of zero length, but that's OK. Side inputs can be empty.
+ data := wk.D.GetAllData(global)
+ if b.MultiMapSideInputData == nil {
+ b.MultiMapSideInputData = map[string]map[string]map[typex.Window]map[string][][]byte{}
+ }
+ if _, ok := b.MultiMapSideInputData[tid]; !ok {
+ b.MultiMapSideInputData[tid] = map[string]map[typex.Window]map[string][][]byte{}
+ }
+ b.MultiMapSideInputData[tid][local] = collateByWindows(data, watermark, wDec, wEnc,
+ func(r io.Reader) map[string][][]byte {
+ kb := kd(r)
+ return map[string][][]byte{
+ string(kb): {vd(r)},
+ }
+ }, func(a, b map[string][][]byte) map[string][][]byte {
+ if len(a) == 0 {
+ return b
+ }
+ for k, vs := range b {
+ a[k] = append(a[k], vs...)
+ }
+ return a
+ })
+ })
+ default:
+ return nil, fmt.Errorf("local input %v (global %v) uses accesspattern %v", local, global, si.GetAccessPattern().GetUrn())
+ }
+ }
+ return func(b *worker.B, tid string, watermark mtime.Time) {
+ for _, prep := range prepSides {
+ prep(b, tid, watermark)
+ }
+ }, nil
+}
+
+func collectionPullDecoder(coldCId string, coders map[string]*pipepb.Coder, comps *pipepb.Components) func(io.Reader) []byte {
+ cID := lpUnknownCoders(coldCId, coders, comps.GetCoders())
+ return pullDecoder(coders[cID], coders)
+}
+
+func getWindowValueCoders(comps *pipepb.Components, col *pipepb.PCollection, coders map[string]*pipepb.Coder) (exec.WindowDecoder, exec.WindowEncoder) {
+ ws := comps.GetWindowingStrategies()[col.GetWindowingStrategyId()]
+ wcID := lpUnknownCoders(ws.GetWindowCoderId(), coders, comps.GetCoders())
+ return makeWindowCoders(coders[wcID])
+}
+
+func sourceTransform(parentID string, sourcePortBytes []byte, outPID string) *pipepb.PTransform {
+ source := &pipepb.PTransform{
+ UniqueName: parentID,
+ Spec: &pipepb.FunctionSpec{
+ Urn: urns.TransformSource,
+ Payload: sourcePortBytes,
+ },
+ Outputs: map[string]string{
+ "i0": outPID,
+ },
+ }
+ return source
+}
+
+func sinkTransform(sinkID string, sinkPortBytes []byte, inPID string) *pipepb.PTransform {
+ source := &pipepb.PTransform{
+ UniqueName: sinkID,
+ Spec: &pipepb.FunctionSpec{
+ Urn: urns.TransformSink,
+ Payload: sinkPortBytes,
+ },
+ Inputs: map[string]string{
+ "i0": inPID,
+ },
+ }
+ return source
+}
+
+func externalEnvironment(ctx context.Context, ep *pipepb.ExternalPayload, wk *worker.W) {
+ conn, err := grpc.Dial(ep.GetEndpoint().GetUrl(), grpc.WithTransportCredentials(insecure.NewCredentials()))
+ if err != nil {
+ panic(fmt.Sprintf("unable to dial sdk worker %v: %v", ep.GetEndpoint().GetUrl(), err))
+ }
+ defer conn.Close()
+ pool := fnpb.NewBeamFnExternalWorkerPoolClient(conn)
+
+ endpoint := &pipepb.ApiServiceDescriptor{
+ Url: wk.Endpoint(),
+ }
+
+ pool.StartWorker(ctx, &fnpb.StartWorkerRequest{
+ WorkerId: wk.ID,
+ ControlEndpoint: endpoint,
+ LoggingEndpoint: endpoint,
+ ArtifactEndpoint: endpoint,
+ ProvisionEndpoint: endpoint,
+ Params: nil,
+ })
+
+ // Job processing happens here, but is orchestrated by other goroutines.
+ // This goroutine blocks until the context is cancelled, signalling
+ // that the pool runner should stop the worker.
+ <-ctx.Done()
+
+ // The previous context was cancelled, so we need a new one
+ // for this request.
+ pool.StopWorker(context.Background(), &fnpb.StopWorkerRequest{
+ WorkerId: wk.ID,
+ })
+}
+
+func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *worker.W) {
+ // TODO fix broken abstraction.
+ // We're starting a worker pool here, because that's the loopback environment.
+ // It's sort of a mess, largely because of loopback, which has
+ // a different flow from a provisioned docker container.
+ e := j.Pipeline.GetComponents().GetEnvironments()[env]
+ switch e.GetUrn() {
+ case urns.EnvExternal:
+ ep := &pipepb.ExternalPayload{}
+ if err := (proto.UnmarshalOptions{}).Unmarshal(e.GetPayload(), ep); err != nil {
+ slog.Error("unmarshing environment payload", err, slog.String("envID", wk.ID))
+ }
+ externalEnvironment(ctx, ep, wk)
+ slog.Info("environment stopped", slog.String("envID", wk.String()), slog.String("job", j.String()))
+ default:
+ panic(fmt.Sprintf("environment %v with urn %v unimplemented", env, e.GetUrn()))
+ }
}
type transformExecuter interface {
@@ -32,3 +476,201 @@ type transformExecuter interface {
ExecuteWith(t *pipepb.PTransform) string
ExecuteTransform(tid string, t *pipepb.PTransform, comps *pipepb.Components, watermark mtime.Time, data [][]byte) *worker.B
}
+
+type processor struct {
+ transformExecuters map[string]transformExecuter
+}
+
+func getOnlyValue[K comparable, V any](in map[K]V) V {
+ if len(in) != 1 {
+ panic(fmt.Sprintf("expected single value map, had %v", len(in)))
+ }
+ for _, v := range in {
+ return v
+ }
+ panic("unreachable")
+}
+
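+// executePipeline translates the job's pipeline into single-transform
+// stages, registers them with the element manager, primes the impulses,
+// and then executes bundles as the engine declares them ready.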
+func executePipeline(ctx context.Context, wk *worker.W, j *jobservices.Job) {
+ pipeline := j.Pipeline
+ comps := proto.Clone(pipeline.GetComponents()).(*pipepb.Components)
+
+ // TODO, configure the preprocessor from pipeline options.
+ // Maybe change these returns to a single struct for convenience and further
+ // annotation?
+
+ handlers := []any{
+ Combine(CombineCharacteristic{EnableLifting: true}),
+ ParDo(ParDoCharacteristic{DisableSDF: true}),
+ Runner(RunnerCharacteristic{
+ SDKFlatten: false,
+ }),
+ }
+
+ proc := processor{
+ transformExecuters: map[string]transformExecuter{},
+ }
+
+ var preppers []transformPreparer
+ for _, h := range handlers {
+ if th, ok := h.(transformPreparer); ok {
+ preppers = append(preppers, th)
+ }
+ if th, ok := h.(transformExecuter); ok {
+ for _, urn := range th.ExecuteUrns() {
+ proc.transformExecuters[urn] = th
+ }
+ }
+ }
+
+ prepro := newPreprocessor(preppers)
+
+ topo := prepro.preProcessGraph(comps)
+ ts := comps.GetTransforms()
+
+ em := engine.NewElementManager(engine.Config{})
+
+ // This is where the Batch -> Streaming tension exists.
+ // We don't *pre* do this, and we need a different mechanism
+ // to sort out processing order.
+ stages := map[string]*stage{}
+ var impulses []string
+ for i, stage := range topo {
+ if len(stage.transforms) != 1 {
+ panic(fmt.Sprintf("unsupported stage[%d]: contains multiple transforms: %v; TODO: implement fusion", i, stage.transforms))
+ }
+ tid := stage.transforms[0]
+ t := ts[tid]
+ urn := t.GetSpec().GetUrn()
+ stage.exe = proc.transformExecuters[urn]
+
+ // Stopgap until everything's moved to handlers.
+ stage.envID = t.GetEnvironmentId()
+ if stage.exe != nil {
+ stage.envID = stage.exe.ExecuteWith(t)
+ }
+ stage.ID = wk.NextStage()
+
+ switch stage.envID {
+ case "": // Runner Transforms
+
+ var onlyOut string
+ for _, out := range t.GetOutputs() {
+ onlyOut = out
+ }
+ stage.OutputsToCoders = map[string]engine.PColInfo{}
+ coders := map[string]*pipepb.Coder{}
+ makeWindowedValueCoder(onlyOut, comps, coders)
+
+ col := comps.GetPcollections()[onlyOut]
+ ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+ wDec, wEnc := getWindowValueCoders(comps, col, coders)
+
+ stage.OutputsToCoders[onlyOut] = engine.PColInfo{
+ GlobalID: onlyOut,
+ WDec: wDec,
+ WEnc: wEnc,
+ EDec: ed,
+ }
+
+ // There are either 0, 1, or many inputs, but they should all be the
+ // same, so break after the first one.
+ for _, global := range t.GetInputs() {
+ col := comps.GetPcollections()[global]
+ ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+ wDec, wEnc := getWindowValueCoders(comps, col, coders)
+ stage.inputInfo = engine.PColInfo{
+ GlobalID: global,
+ WDec: wDec,
+ WEnc: wEnc,
+ EDec: ed,
+ }
+ break
+ }
+
+ switch urn {
+ case urns.TransformGBK:
+ em.AddStage(stage.ID, []string{getOnlyValue(t.GetInputs())}, nil, []string{getOnlyValue(t.GetOutputs())})
+ for _, global := range t.GetInputs() {
+ col := comps.GetPcollections()[global]
+ ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+ wDec, wEnc := getWindowValueCoders(comps, col, coders)
+ stage.inputInfo = engine.PColInfo{
+ GlobalID: global,
+ WDec: wDec,
+ WEnc: wEnc,
+ EDec: ed,
+ }
+ }
+ em.StageAggregates(stage.ID)
+ case urns.TransformImpulse:
+ impulses = append(impulses, stage.ID)
+ em.AddStage(stage.ID, nil, nil, []string{getOnlyValue(t.GetOutputs())})
+ case urns.TransformFlatten:
+ inputs := maps.Values(t.GetInputs())
+ sort.Strings(inputs)
+ em.AddStage(stage.ID, inputs, nil, []string{getOnlyValue(t.GetOutputs())})
+ }
+ stages[stage.ID] = stage
+ wk.Descriptors[stage.ID] = stage.desc
+ case wk.ID:
+ // Great! This is for this environment. // Broken abstraction.
+ buildStage(stage, tid, t, comps, wk)
+ stages[stage.ID] = stage
+ slog.Debug("pipelineBuild", slog.Group("stage", slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName())))
+ outputs := maps.Keys(stage.OutputsToCoders)
+ sort.Strings(outputs)
+ em.AddStage(stage.ID, []string{stage.mainInputPCol}, stage.sides, outputs)
+ default:
+ err := fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId())
+ slog.Error("Execute", err)
+ panic(err)
+ }
+ }
+
+ // Prime the initial impulses, since we now know what consumes them.
+ for _, id := range impulses {
+ em.Impulse(id)
+ }
+
+ // Execute stages here
+ for rb := range em.Bundles(ctx, wk.NextInst) {
+ s := stages[rb.StageID]
+ s.Execute(j, wk, comps, em, rb)
+ }
+ slog.Info("pipeline done!", slog.String("job", j.String()))
+}
+
+// RunPipeline starts the main thread for executing this job.
+// It's analogous to the manager-side process for a distributed pipeline.
+// It starts up the "workers" the job needs.
+func RunPipeline(j *jobservices.Job) {
+ j.SendMsg("starting " + j.String())
+ j.Start()
+
+ // In a "proper" runner, we'd iterate through all the
+ // environments, and start up docker containers, but
+ // here, we only want and need the go one, operating
+ // in loopback mode.
+ env := "go"
+ wk := worker.New(env) // Cheating by having the worker id match the environment id.
+ go wk.Serve()
+
+ // When this function exits, cancel the job so the environment shuts down.
+ defer func() {
+ j.CancelFn()
+ }()
+ go runEnvironment(j.RootCtx, j, env, wk)
+
+ j.SendMsg("running " + j.String())
+ j.Running()
+
+ executePipeline(j.RootCtx, wk, j)
+ j.SendMsg("pipeline completed " + j.String())
+
+ // Stop the worker.
+ wk.Stop()
+
+ j.SendMsg("terminating " + j.String())
+ j.Done()
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go
new file mode 100644
index 00000000000..de7247486bb
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go
@@ -0,0 +1,417 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "context"
+ "os"
+ "testing"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/metrics"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/options/jobopts"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/filter"
+ "github.com/apache/beam/sdks/v2/go/test/integration/primitives"
+)
+
+func initRunner(t *testing.T) {
+ t.Helper()
+ if *jobopts.Endpoint == "" {
+ s := jobservices.NewServer(0, RunPipeline)
+ *jobopts.Endpoint = s.Endpoint()
+ go s.Serve()
+ t.Cleanup(func() {
+ *jobopts.Endpoint = ""
+ s.Stop()
+ })
+ }
+ if !jobopts.IsLoopback() {
+ *jobopts.EnvironmentType = "loopback"
+ }
+ // Since we force loopback, avoid cross-compilation.
+ f, err := os.CreateTemp("", "dummy")
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Cleanup(func() { os.Remove(f.Name()) })
+ *jobopts.WorkerBinary = f.Name()
+}
+
+func execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error) {
+ return universal.Execute(ctx, p)
+}
+
+func executeWithT(ctx context.Context, t *testing.T, p *beam.Pipeline) (beam.PipelineResult, error) {
+ t.Log("startingTest - ", t.Name())
+ return execute(ctx, p)
+}
+
+func init() {
+ // Not actually being used, but explicitly registering
+ // will avoid accidentally using a different runner for
+ // the tests if I change things later.
+ beam.RegisterRunner("testlocal", execute)
+}
+
+func TestRunner_Pipelines(t *testing.T) {
+ initRunner(t)
+
+ tests := []struct {
+ name string
+ pipeline func(s beam.Scope)
+ metrics func(t *testing.T, pr beam.PipelineResult)
+ }{
+ {
+ name: "simple",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col := beam.ParDo(s, dofn1, imp)
+ beam.ParDo(s, &int64Check{
+ Name: "simple",
+ Want: []int{1, 2, 3},
+ }, col)
+ },
+ }, {
+ name: "sequence",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ beam.Seq(s, imp, dofn1, dofn2, dofn2, dofn2, &int64Check{Name: "sequence", Want: []int{4, 5, 6}})
+ },
+ }, {
+ name: "gbk",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col := beam.ParDo(s, dofnKV, imp)
+ gbk := beam.GroupByKey(s, col)
+ beam.Seq(s, gbk, dofnGBK, &int64Check{Name: "gbk", Want: []int{9, 12}})
+ },
+ }, {
+ name: "gbk2",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col := beam.ParDo(s, dofnKV2, imp)
+ gbk := beam.GroupByKey(s, col)
+ beam.Seq(s, gbk, dofnGBK2, &stringCheck{Name: "gbk2", Want: []string{"aaa", "bbb"}})
+ },
+ }, {
+ name: "gbk3",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col := beam.ParDo(s, dofnKV3, imp)
+ gbk := beam.GroupByKey(s, col)
+ beam.Seq(s, gbk, dofnGBK3, &stringCheck{Name: "gbk3", Want: []string{"{a 1}: {a 1}"}})
+ },
+ }, {
+ name: "sink_nooutputs",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ beam.ParDo0(s, dofnSink, imp)
+ },
+ metrics: func(t *testing.T, pr beam.PipelineResult) {
+ qr := pr.Metrics().Query(func(sr metrics.SingleResult) bool {
+ return sr.Name() == "sunk"
+ })
+ if got, want := qr.Counters()[0].Committed, int64(73); got != want {
+ t.Errorf("pr.Metrics.Query(Name = \"sunk\")).Committed = %v, want %v", got, want)
+ }
+ },
+ }, {
+ name: "fork_impulse",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col1 := beam.ParDo(s, dofn1, imp)
+ col2 := beam.ParDo(s, dofn1, imp)
+ beam.ParDo(s, &int64Check{
+ Name: "fork check1",
+ Want: []int{1, 2, 3},
+ }, col1)
+ beam.ParDo(s, &int64Check{
+ Name: "fork check2",
+ Want: []int{1, 2, 3},
+ }, col2)
+ },
+ }, {
+ name: "fork_postDoFn",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col := beam.ParDo(s, dofn1, imp)
+ beam.ParDo(s, &int64Check{
+ Name: "fork check1",
+ Want: []int{1, 2, 3},
+ }, col)
+ beam.ParDo(s, &int64Check{
+ Name: "fork check2",
+ Want: []int{1, 2, 3},
+ }, col)
+ },
+ }, {
+ name: "fork_multipleOutputs1",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col1, col2, col3, col4, col5 := beam.ParDo5(s, dofn1x5, imp)
+ beam.ParDo(s, &int64Check{
+ Name: "col1",
+ Want: []int{1, 6},
+ }, col1)
+ beam.ParDo(s, &int64Check{
+ Name: "col2",
+ Want: []int{2, 7},
+ }, col2)
+ beam.ParDo(s, &int64Check{
+ Name: "col3",
+ Want: []int{3, 8},
+ }, col3)
+ beam.ParDo(s, &int64Check{
+ Name: "col4",
+ Want: []int{4, 9},
+ }, col4)
+ beam.ParDo(s, &int64Check{
+ Name: "col5",
+ Want: []int{5, 10},
+ }, col5)
+ },
+ }, {
+ name: "fork_multipleOutputs2",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col1, col2, col3, col4, col5 := beam.ParDo5(s, dofn1x5, imp)
+ beam.ParDo(s, &int64Check{
+ Name: "col1",
+ Want: []int{1, 6},
+ }, col1)
+ beam.ParDo(s, &int64Check{
+ Name: "col2",
+ Want: []int{2, 7},
+ }, col2)
+ beam.ParDo(s, &int64Check{
+ Name: "col3",
+ Want: []int{3, 8},
+ }, col3)
+ beam.ParDo(s, &int64Check{
+ Name: "col4",
+ Want: []int{4, 9},
+ }, col4)
+ beam.ParDo(s, &int64Check{
+ Name: "col5",
+ Want: []int{5, 10},
+ }, col5)
+ },
+ }, {
+ name: "flatten",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col1 := beam.ParDo(s, dofn1, imp)
+ col2 := beam.ParDo(s, dofn1, imp)
+ flat := beam.Flatten(s, col1, col2)
+ beam.ParDo(s, &int64Check{
+ Name: "flatten check",
+ Want: []int{1, 1, 2, 2, 3, 3},
+ }, flat)
+ },
+ }, {
+ name: "sideinput_iterable_oneimpulse",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col1 := beam.ParDo(s, dofn1, imp)
+ sum := beam.ParDo(s, dofn2x1, imp, beam.SideInput{Input: col1})
+ beam.ParDo(s, &int64Check{
+ Name: "iter sideinput check",
+ Want: []int{6},
+ }, sum)
+ },
+ }, {
+ name: "sideinput_iterable_twoimpulse",
+ pipeline: func(s beam.Scope) {
+ imp1 := beam.Impulse(s)
+ col1 := beam.ParDo(s, dofn1, imp1)
+ imp2 := beam.Impulse(s)
+ sum := beam.ParDo(s, dofn2x1, imp2, beam.SideInput{Input: col1})
+ beam.ParDo(s, &int64Check{
+ Name: "iter sideinput check",
+ Want: []int{6},
+ }, sum)
+ },
+ }, {
+ name: "sideinput_iterableKV",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col1 := beam.ParDo(s, dofnKV, imp)
+ keys, sum := beam.ParDo2(s, dofn2x2KV, imp, beam.SideInput{Input: col1})
+ beam.ParDo(s, &stringCheck{
+ Name: "iterKV sideinput check K",
+ Want: []string{"a", "a", "a", "b", "b", "b"},
+ }, keys)
+ beam.ParDo(s, &int64Check{
+ Name: "iterKV sideinput check V",
+ Want: []int{21},
+ }, sum)
+ },
+ }, {
+ name: "sideinput_iterableKV",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col1 := beam.ParDo(s, dofnKV, imp)
+ keys, sum := beam.ParDo2(s, dofn2x2KV, imp, beam.SideInput{Input: col1})
+ beam.ParDo(s, &stringCheck{
+ Name: "iterKV sideinput check K",
+ Want: []string{"a", "a", "a", "b", "b", "b"},
+ }, keys)
+ beam.ParDo(s, &int64Check{
+ Name: "iterKV sideinput check V",
+ Want: []int{21},
+ }, sum)
+ },
+ }, {
+ name: "sideinput_multimap",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col1 := beam.ParDo(s, dofnKV, imp)
+ keys := filter.Distinct(s, beam.DropValue(s, col1))
+ ks, sum := beam.ParDo2(s, dofnMultiMap, keys, beam.SideInput{Input: col1})
+ beam.ParDo(s, &stringCheck{
+ Name: "multiMap sideinput check K",
+ Want: []string{"a", "b"},
+ }, ks)
+ beam.ParDo(s, &int64Check{
+ Name: "multiMap sideinput check V",
+ Want: []int{9, 12},
+ }, sum)
+ },
+ }, {
+ // Ensures topological sort is correct.
+ name: "sideinput_2iterable",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col0 := beam.ParDo(s, dofn1, imp)
+ col1 := beam.ParDo(s, dofn1, imp)
+ col2 := beam.ParDo(s, dofn2, col1)
+ sum := beam.ParDo(s, dofn3x1, col0, beam.SideInput{Input: col1}, beam.SideInput{Input: col2})
+ beam.ParDo(s, &int64Check{
+ Name: "iter sideinput check",
+ Want: []int{16, 17, 18},
+ }, sum)
+ },
+ }, {
+ name: "combine_perkey",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ in := beam.ParDo(s, dofn1kv, imp)
+ keyedsum := beam.CombinePerKey(s, combineIntSum, in)
+ sum := beam.DropKey(s, keyedsum)
+ beam.ParDo(s, &int64Check{
+ Name: "combine",
+ Want: []int{6},
+ }, sum)
+ },
+ }, {
+ name: "combine_global",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ in := beam.ParDo(s, dofn1, imp)
+ sum := beam.Combine(s, combineIntSum, in)
+ beam.ParDo(s, &int64Check{
+ Name: "combine",
+ Want: []int{6},
+ }, sum)
+ },
+ }, {
+ name: "sdf_single_split",
+ pipeline: func(s beam.Scope) {
+ configs := beam.Create(s, SourceConfig{NumElements: 10, InitialSplits: 1})
+ in := beam.ParDo(s, &intRangeFn{}, configs)
+ beam.ParDo(s, &int64Check{
+ Name: "sdf_single",
+ Want: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
+ }, in)
+ },
+ }, {
+ name: "WindowedSideInputs",
+ pipeline: primitives.ValidateWindowedSideInputs,
+ }, {
+ name: "WindowSums_GBK",
+ pipeline: primitives.WindowSums_GBK,
+ }, {
+ name: "WindowSums_Lifted",
+ pipeline: primitives.WindowSums_Lifted,
+ }, {
+ name: "ProcessContinuations_globalCombine",
+ pipeline: func(s beam.Scope) {
+ out := beam.ParDo(s, &selfCheckpointingDoFn{}, beam.Impulse(s))
+ passert.Count(s, out, "num ints", 10)
+ },
+ }, {
+ name: "flatten_to_sideInput",
+ pipeline: func(s beam.Scope) {
+ imp := beam.Impulse(s)
+ col1 := beam.ParDo(s, dofn1, imp)
+ col2 := beam.ParDo(s, dofn1, imp)
+ flat := beam.Flatten(s, col1, col2)
+ beam.ParDo(s, &int64Check{
+ Name: "flatten check",
+ Want: []int{1, 1, 2, 2, 3, 3},
+ }, flat)
+ passert.NonEmpty(s, flat)
+ },
+ },
+ }
+ // TODO: Explicit DoFn Failure case.
+ // TODO: Session windows, where some are not merged.
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ p, s := beam.NewPipelineWithRoot()
+ test.pipeline(s)
+ pr, err := executeWithT(context.Background(), t, p)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if test.metrics != nil {
+ test.metrics(t, pr)
+ }
+ })
+ }
+}
+
+func TestRunner_Metrics(t *testing.T) {
+ initRunner(t)
+ t.Run("counter", func(t *testing.T) {
+ p, s := beam.NewPipelineWithRoot()
+ imp := beam.Impulse(s)
+ beam.ParDo(s, dofn1Counter, imp)
+ pr, err := executeWithT(context.Background(), t, p)
+ if err != nil {
+ t.Fatal(err)
+ }
+ qr := pr.Metrics().Query(func(sr metrics.SingleResult) bool {
+ return sr.Name() == "count"
+ })
+ if got, want := qr.Counters()[0].Committed, int64(1); got != want {
+ t.Errorf("pr.Metrics.Query(Name = \"count\")).Committed = %v, want %v", got, want)
+ }
+ })
+}
+
+// TODO: PCollection metrics tests, in particular for element counts, in multi transform pipelines
+// There's a doubling bug since we re-use the same pcollection IDs for the source & sink, and
+// don't do any re-writing.
+
+func TestMain(m *testing.M) {
+ ptest.MainWithDefault(m, "testlocal")
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/separate_test.go b/sdks/go/pkg/beam/runners/prism/internal/separate_test.go
new file mode 100644
index 00000000000..edfe3736503
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/separate_test.go
@@ -0,0 +1,593 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "net/http"
+ "net/rpc"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/io/rtrackers/offsetrange"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/register"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/passert"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/transforms/stats"
+ "golang.org/x/exp/slog"
+)
+
+// separate_test.go retains structures and tests to ensure the runner can
+// perform separation, and terminate checkpoints.
+
+// Global variable, so only one is registered with the OS.
+var ws = &Watchers{}
+
+// TestSeparation validates that the runner is able to split
+// elements in time and space. Beam has a few mechanisms to
+// do this.
+//
+// First is channel splits, where a slowly processing
+// bundle might have its remaining buffered elements truncated
+// so they can be processed by another bundle,
+// possibly simultaneously.
+//
+// Second is sub element splitting, where a single element
+// in an SDF might be split into smaller restrictions.
+//
+// Third with Checkpointing or ProcessContinuations,
+// a User DoFn may decide to defer processing of an element
+// until later, permitting a bundle to terminate earlier,
+// delaying processing.
+//
+// All these may be tested locally or in process with a small
+// server the DoFns can connect to. This can then indicate which
+// elements or positions are considered "sentinels".
+//
+// When a sentinel is to be processed, the DoFn instead blocks.
+// The goal for Splitting tests is to succeed only when all
+// sentinels are blocking waiting to be processed.
+// This indicates the runner has "separated" the sentinels, hence
+// the name "separation harness tests".
+//
+// Delayed Process Continuations can be similarly tested,
+// as this emulates external processing servers anyway.
+// It's much simpler though, as the request is to determine if
+// a given element should be delayed or not. This could be used
+// for arbitrarily complex splitting patterns, as desired.
+func TestSeparation(t *testing.T) {
+ initRunner(t)
+
+ ws.initRPCServer()
+
+ tests := []struct {
+ name string
+ pipeline func(s beam.Scope)
+ metrics func(t *testing.T, pr beam.PipelineResult)
+ }{
+ {
+ name: "ProcessContinuations_combine_globalWindow",
+ pipeline: func(s beam.Scope) {
+ count := 10
+ imp := beam.Impulse(s)
+ out := beam.ParDo(s, &sepHarnessSdfStream{
+ Base: sepHarnessBase{
+ WatcherID: ws.newWatcher(3),
+ Sleep: time.Second,
+ IsSentinelEncoded: beam.EncodedFunc{Fn: reflectx.MakeFunc(allSentinel)},
+ LocalService: ws.serviceAddress,
+ },
+ RestSize: int64(count),
+ }, imp)
+ passert.Count(s, out, "global num ints", count)
+ },
+ }, {
+ name: "ProcessContinuations_stepped_combine_globalWindow",
+ pipeline: func(s beam.Scope) {
+ count := 10
+ imp := beam.Impulse(s)
+ out := beam.ParDo(s, &singleStepSdfStream{
+ Sleep: time.Second,
+ RestSize: int64(count),
+ }, imp)
+ passert.Count(s, out, "global stepped num ints", count)
+ sum := beam.ParDo(s, dofn2x1, imp, beam.SideInput{Input: out})
+ beam.ParDo(s, &int64Check{Name: "stepped", Want: []int{45}}, sum)
+ },
+ }, {
+ name: "ProcessContinuations_stepped_combine_fixedWindow",
+ pipeline: func(s beam.Scope) {
+ elms, mod := 1000, 10
+ count := int(elms / mod)
+ imp := beam.Impulse(s)
+ out := beam.ParDo(s, &eventtimeSDFStream{
+ Sleep: time.Second,
+ RestSize: int64(elms),
+ Mod: int64(mod),
+ Fixed: 1,
+ }, imp)
+ windowed := beam.WindowInto(s, window.NewFixedWindows(time.Second*10), out)
+ sum := stats.Sum(s, windowed)
+ // We expect each window to be processed ASAP, and produced one
+ // at a time, with the same results.
+ beam.ParDo(s, &int64Check{Name: "single", Want: []int{55}}, sum)
+ // But we need to receive the expected number of identical results
+ gsum := beam.WindowInto(s, window.NewGlobalWindows(), sum)
+ passert.Count(s, gsum, "total sums", count)
+ },
+ },
+ }
+
+ // TODO: Channel Splits
+ // TODO: SubElement/dynamic splits.
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ p, s := beam.NewPipelineWithRoot()
+ test.pipeline(s)
+ pr, err := executeWithT(context.Background(), t, p)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if test.metrics != nil {
+ test.metrics(t, pr)
+ }
+ })
+ }
+}
+
+func init() {
+ register.Function1x1(allSentinel)
+}
+
+// allSentinel indicates that all elements are sentinels.
+func allSentinel(v beam.T) bool {
+ return true
+}
+
+// watcher is a single instance of the sentinel counters.
+type watcher struct {
+ id int
+ mu sync.Mutex
+ sentinelCount, sentinelCap int
+}
+
+func (w *watcher) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.Int("id", w.id),
+ slog.Int("sentinelCount", w.sentinelCount),
+ slog.Int("sentinelCap", w.sentinelCap),
+ )
+}
+
+// Watchers is a "net/rpc" service.
+type Watchers struct {
+ mu sync.Mutex
+ nextID int
+ lookup map[int]*watcher
+ serviceOnce sync.Once
+ serviceAddress string
+}
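+
+// The harness flow against this service, in brief: each sentinel calls
+// Watchers.Block and waits; a poller goroutine calls Watchers.Check until
+// every expected sentinel has blocked, which releases them all at once.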
+
+// Args is the set of parameters to the watchers RPC methods.
+type Args struct {
+ WatcherID int
+}
+
+// Block is called once per sentinel, to indicate it will block
+// until all sentinels are blocked.
+func (ws *Watchers) Block(args *Args, _ *bool) error {
+ ws.mu.Lock()
+ defer ws.mu.Unlock()
+ w, ok := ws.lookup[args.WatcherID]
+ if !ok {
+ return fmt.Errorf("no watcher with id %v", args.WatcherID)
+ }
+ w.mu.Lock()
+ w.sentinelCount++
+ w.mu.Unlock()
+ return nil
+}
+
+// Check returns whether the sentinels are unblocked or not.
+func (ws *Watchers) Check(args *Args, unblocked *bool) error {
+ ws.mu.Lock()
+ defer ws.mu.Unlock()
+ w, ok := ws.lookup[args.WatcherID]
+ if !ok {
+ return fmt.Errorf("no watcher with id %v", args.WatcherID)
+ }
+ w.mu.Lock()
+ *unblocked = w.sentinelCount >= w.sentinelCap
+ w.mu.Unlock()
+ slog.Debug("sentinel target for watcher%d is %d/%d. unblocked=%v", args.WatcherID, w.sentinelCount, w.sentinelCap, *unblocked)
+ return nil
+}
+
+// Delay returns whether the sentinels should delay.
+// This increments the sentinel cap, and returns unblocked.
+// Intended to validate ProcessContinuation behavior.
+func (ws *Watchers) Delay(args *Args, delay *bool) error {
+ ws.mu.Lock()
+ defer ws.mu.Unlock()
+ w, ok := ws.lookup[args.WatcherID]
+ if !ok {
+ return fmt.Errorf("no watcher with id %v", args.WatcherID)
+ }
+ w.mu.Lock()
+ w.sentinelCount++
+ // Delay as long as the sentinel count is under the cap.
+ *delay = w.sentinelCount < w.sentinelCap
+ w.mu.Unlock()
+ slog.Debug("Delay: sentinel target", "watcher", w, slog.Bool("delay", *delay))
+ return nil
+}
+
+func (ws *Watchers) initRPCServer() {
+ ws.serviceOnce.Do(func() {
+ l, err := net.Listen("tcp", ":0")
+ if err != nil {
+ panic(err)
+ }
+ rpc.Register(ws)
+ rpc.HandleHTTP()
+ go http.Serve(l, nil)
+ ws.serviceAddress = l.Addr().String()
+ })
+}
+
+// newWatcher starts an rpc server to manage state for watching for
+// sentinels across local machines.
+func (ws *Watchers) newWatcher(sentinelCap int) int {
+ ws.mu.Lock()
+ defer ws.mu.Unlock()
+ ws.initRPCServer()
+ if ws.lookup == nil {
+ ws.lookup = map[int]*watcher{}
+ }
+ w := &watcher{id: ws.nextID, sentinelCap: sentinelCap}
+ ws.nextID++
+ ws.lookup[w.id] = w
+ return w.id
+}
+
+// sepHarnessBase contains fields and functions that are shared by all
+// versions of the separation harness.
+type sepHarnessBase struct {
+ WatcherID int
+ Sleep time.Duration
+ IsSentinelEncoded beam.EncodedFunc
+ LocalService string
+}
+
+// One connection per binary.
+var (
+ sepClientOnce sync.Once
+ sepClient *rpc.Client
+ sepClientMu sync.Mutex
+ sepWaitMap map[int]chan struct{}
+)
+
+func (fn *sepHarnessBase) setup() error {
+ sepClientMu.Lock()
+ defer sepClientMu.Unlock()
+ sepClientOnce.Do(func() {
+ client, err := rpc.DialHTTP("tcp", fn.LocalService)
+ if err != nil {
+ slog.Error("failed to dial sentinels server", err, slog.String("endpoint", fn.LocalService))
+ panic(fmt.Sprintf("dialing sentinels server %v: %v", fn.LocalService, err))
+ }
+ sepClient = client
+ sepWaitMap = map[int]chan struct{}{}
+ })
+
+ // Check if there's already a local channel for this id, and if not,
+ // start a watcher goroutine to poll and unblock the harness when
+ // the expected number of sentinels is reached.
+ if _, ok := sepWaitMap[fn.WatcherID]; ok {
+ return nil
+ }
+ // We need a channel to block on for this watcherID.
+ // We use a channel instead of a wait group since the finished
+ // count is hosted in a different process.
+ c := make(chan struct{})
+ sepWaitMap[fn.WatcherID] = c
+ go func(id int, c chan struct{}) {
+ for {
+ time.Sleep(time.Second * 1) // Check counts every second.
+ sepClientMu.Lock()
+ var unblock bool
+ err := sepClient.Call("Watchers.Check", &Args{WatcherID: id}, &unblock)
+ if err != nil {
+ slog.Error("Watchers.Check: sentinels server error", err, slog.String("endpoint", fn.LocalService))
+ panic("sentinel server error")
+ }
+ if unblock {
+ close(c) // unblock all the local waiters.
+ slog.Debug("sentinel target for watcher, unblocking", slog.Int("watcherID", id))
+ sepClientMu.Unlock()
+ return
+ }
+ slog.Debug("sentinel target for watcher not met", slog.Int("watcherID", id))
+ sepClientMu.Unlock()
+ }
+ }(fn.WatcherID, c)
+ return nil
+}
+
+func (fn *sepHarnessBase) block() {
+ sepClientMu.Lock()
+ var ignored bool
+ err := sepClient.Call("Watchers.Block", &Args{WatcherID: fn.WatcherID}, &ignored)
+ if err != nil {
+ slog.Error("Watchers.Block error", err, slog.String("endpoint", fn.LocalService))
+ panic(err)
+ }
+ c := sepWaitMap[fn.WatcherID]
+ sepClientMu.Unlock()
+
+ // Block until the watcher closes the channel.
+ <-c
+}
+
+// delay informs the DoFn whether to return a delayed ProcessContinuation for this position.
+func (fn *sepHarnessBase) delay() bool {
+ sepClientMu.Lock()
+ defer sepClientMu.Unlock()
+ var delay bool
+ err := sepClient.Call("Watchers.Delay", &Args{WatcherID: fn.WatcherID}, &delay)
+ if err != nil {
+ slog.Error("Watchers.Delay error", err)
+ panic(err)
+ }
+ return delay
+}
+
+// sepHarness is a simple DoFn that blocks when reaching a sentinel.
+// It's useful for testing blocks on channel splits.
+type sepHarness struct {
+ Base sepHarnessBase
+}
+
+func (fn *sepHarness) Setup() error {
+ return fn.Base.setup()
+}
+
+func (fn *sepHarness) ProcessElement(v beam.T) beam.T {
+ if fn.Base.IsSentinelEncoded.Fn.Call([]any{v})[0].(bool) {
+ slog.Debug("blocking on sentinel", slog.Any("sentinel", v))
+ fn.Base.block()
+ slog.Debug("unblocking from sentinel", slog.Any("sentinel", v))
+ } else {
+ time.Sleep(fn.Base.Sleep)
+ }
+ return v
+}
+
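+// sepHarnessSdf is an SDF version of the separation harness: it claims
+// positions one at a time, blocking on any (position, element) pair that
+// the sentinel predicate selects, to test sub-element splitting.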
+type sepHarnessSdf struct {
+ Base sepHarnessBase
+ RestSize int64
+}
+
+func (fn *sepHarnessSdf) Setup() error {
+ return fn.Base.setup()
+}
+
+func (fn *sepHarnessSdf) CreateInitialRestriction(v beam.T) offsetrange.Restriction {
+ return offsetrange.Restriction{Start: 0, End: fn.RestSize}
+}
+
+func (fn *sepHarnessSdf) SplitRestriction(v beam.T, r offsetrange.Restriction) []offsetrange.Restriction {
+ return r.EvenSplits(2)
+}
+
+func (fn *sepHarnessSdf) RestrictionSize(v beam.T, r offsetrange.Restriction) float64 {
+ return r.Size()
+}
+
+func (fn *sepHarnessSdf) CreateTracker(r offsetrange.Restriction) *sdf.LockRTracker {
+ return sdf.NewLockRTracker(offsetrange.NewTracker(r))
+}
+
+func (fn *sepHarnessSdf) ProcessElement(rt *sdf.LockRTracker, v beam.T, emit func(beam.T)) {
+ i := rt.GetRestriction().(offsetrange.Restriction).Start
+ for rt.TryClaim(i) {
+ if fn.Base.IsSentinelEncoded.Fn.Call([]any{i, v})[0].(bool) {
+ slog.Debug("blocking on sentinel", slog.Group("sentinel", slog.Any("value", v), slog.Int64("pos", i)))
+ fn.Base.block()
+ slog.Debug("unblocking from sentinel", slog.Group("sentinel", slog.Any("value", v), slog.Int64("pos", i)))
+ } else {
+ time.Sleep(fn.Base.Sleep)
+ }
+ emit(v)
+ i++
+ }
+}
+
+func init() {
+ register.DoFn3x1[*sdf.LockRTracker, beam.T, func(beam.T), sdf.ProcessContinuation]((*sepHarnessSdfStream)(nil))
+ register.Emitter1[beam.T]()
+ register.DoFn3x1[*sdf.LockRTracker, beam.T, func(int64), sdf.ProcessContinuation]((*singleStepSdfStream)(nil))
+ register.Emitter1[int64]()
+ register.DoFn4x1[*CWE, *sdf.LockRTracker, beam.T, func(beam.EventTime, int64), sdf.ProcessContinuation]((*eventtimeSDFStream)(nil))
+ register.Emitter2[beam.EventTime, int64]()
+}
+
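+// sepHarnessSdfStream is the ProcessContinuation version of the harness:
+// sentinel elements are repeatedly delayed with ResumeProcessingIn until
+// the watcher clears them, after which the whole restriction is emitted.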
+type sepHarnessSdfStream struct {
+ Base sepHarnessBase
+ RestSize int64
+}
+
+func (fn *sepHarnessSdfStream) Setup() error {
+ return fn.Base.setup()
+}
+
+func (fn *sepHarnessSdfStream) CreateInitialRestriction(v beam.T) offsetrange.Restriction {
+ return offsetrange.Restriction{Start: 0, End: fn.RestSize}
+}
+
+func (fn *sepHarnessSdfStream) SplitRestriction(v beam.T, r offsetrange.Restriction) []offsetrange.Restriction {
+ return r.EvenSplits(2)
+}
+
+func (fn *sepHarnessSdfStream) RestrictionSize(v beam.T, r offsetrange.Restriction) float64 {
+ return r.Size()
+}
+
+func (fn *sepHarnessSdfStream) CreateTracker(r offsetrange.Restriction) *sdf.LockRTracker {
+ return sdf.NewLockRTracker(offsetrange.NewTracker(r))
+}
+
+func (fn *sepHarnessSdfStream) ProcessElement(rt *sdf.LockRTracker, v beam.T, emit func(beam.T)) sdf.ProcessContinuation {
+ if fn.Base.IsSentinelEncoded.Fn.Call([]any{v})[0].(bool) {
+ if fn.Base.delay() {
+ slog.Debug("delaying on sentinel", slog.Group("sentinel", slog.Any("value", v)))
+ return sdf.ResumeProcessingIn(fn.Base.Sleep)
+ }
+ slog.Debug("cleared to process sentinel", slog.Group("sentinel", slog.Any("value", v)))
+ }
+ r := rt.GetRestriction().(offsetrange.Restriction)
+ i := r.Start
+ for rt.TryClaim(i) {
+ emit(v)
+ i++
+ }
+ return sdf.StopProcessing()
+}
+
+// singleStepSdfStream only emits a single position at a time then sleeps.
+// Stops when a restriction of size 0 is provided.
+type singleStepSdfStream struct {
+ RestSize int64
+ Sleep time.Duration
+}
+
+func (fn *singleStepSdfStream) Setup() error {
+ return nil
+}
+
+func (fn *singleStepSdfStream) CreateInitialRestriction(v beam.T) offsetrange.Restriction {
+ return offsetrange.Restriction{Start: 0, End: fn.RestSize}
+}
+
+func (fn *singleStepSdfStream) SplitRestriction(v beam.T, r offsetrange.Restriction) []offsetrange.Restriction {
+ return r.EvenSplits(2)
+}
+
+func (fn *singleStepSdfStream) RestrictionSize(v beam.T, r offsetrange.Restriction) float64 {
+ return r.Size()
+}
+
+func (fn *singleStepSdfStream) CreateTracker(r offsetrange.Restriction) *sdf.LockRTracker {
+ return sdf.NewLockRTracker(offsetrange.NewTracker(r))
+}
+
+func (fn *singleStepSdfStream) ProcessElement(rt *sdf.LockRTracker, v beam.T, emit func(int64)) sdf.ProcessContinuation {
+ r := rt.GetRestriction().(offsetrange.Restriction)
+ i := r.Start
+ if r.Size() < 1 {
+ slog.Debug("size 0 restriction, stoping to process sentinel", slog.Any("value", v))
+ return sdf.StopProcessing()
+ }
+ slog.Debug("emitting element to restriction", slog.Any("value", v), slog.Group("restriction",
+ slog.Any("value", v),
+ slog.Float64("size", r.Size()),
+ slog.Int64("pos", i),
+ ))
+ if rt.TryClaim(i) {
+ emit(i)
+ }
+ return sdf.ResumeProcessingIn(fn.Sleep)
+}
+
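+// eventtimeSDFStream emits a single timestamped element per resumption,
+// with timestamps just before successive second boundaries, so that fixed
+// windows close one at a time as the manual watermark advances.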
+type eventtimeSDFStream struct {
+ RestSize, Mod, Fixed int64
+ Sleep time.Duration
+}
+
+func (fn *eventtimeSDFStream) Setup() error {
+ return nil
+}
+
+func (fn *eventtimeSDFStream) CreateInitialRestriction(v beam.T) offsetrange.Restriction {
+ return offsetrange.Restriction{Start: 0, End: fn.RestSize}
+}
+
+func (fn *eventtimeSDFStream) SplitRestriction(v beam.T, r offsetrange.Restriction) []offsetrange.Restriction {
+ // No split
+ return []offsetrange.Restriction{r}
+}
+
+func (fn *eventtimeSDFStream) RestrictionSize(v beam.T, r offsetrange.Restriction) float64 {
+ return r.Size()
+}
+
+func (fn *eventtimeSDFStream) CreateTracker(r offsetrange.Restriction) *sdf.LockRTracker {
+ return sdf.NewLockRTracker(offsetrange.NewTracker(r))
+}
+
+func (fn *eventtimeSDFStream) ProcessElement(_ *CWE, rt *sdf.LockRTracker, v beam.T, emit func(beam.EventTime, int64)) sdf.ProcessContinuation {
+ r := rt.GetRestriction().(offsetrange.Restriction)
+ i := r.Start
+ if r.Size() < 1 {
+ slog.Debug("size 0 restriction, stoping to process sentinel", slog.Any("value", v))
+ return sdf.StopProcessing()
+ }
+ slog.Debug("emitting element to restriction", slog.Any("value", v), slog.Group("restriction",
+ slog.Any("value", v),
+ slog.Float64("size", r.Size()),
+ slog.Int64("pos", i),
+ ))
+ if rt.TryClaim(i) {
+ timestamp := mtime.FromMilliseconds(int64((i + 1) * 1000)).Subtract(10 * time.Millisecond)
+ v := (i % fn.Mod) + fn.Fixed
+ emit(timestamp, v)
+ }
+ return sdf.ResumeProcessingIn(fn.Sleep)
+}
+
+func (fn *eventtimeSDFStream) InitialWatermarkEstimatorState(_ beam.EventTime, _ offsetrange.Restriction, _ beam.T) int64 {
+ return int64(mtime.MinTimestamp)
+}
+
+func (fn *eventtimeSDFStream) CreateWatermarkEstimator(initialState int64) *CWE {
+ return &CWE{Watermark: initialState}
+}
+
+func (fn *eventtimeSDFStream) WatermarkEstimatorState(e *CWE) int64 {
+ return e.Watermark
+}
+
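+// CWE is a manual watermark estimator used by eventtimeSDFStream.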
+type CWE struct {
+ Watermark int64 // uses int64, since the SDK prevents mtime.Time from serialization.
+}
+
+func (e *CWE) CurrentWatermark() time.Time {
+ return mtime.Time(e.Watermark).ToTime()
+}
+
+func (e *CWE) ObserveTimestamp(ts time.Time) {
+ // Hold the watermark back slightly from the observed timestamp so
+ // window boundaries only progress after the elements are emitted.
+ e.Watermark = int64(mtime.FromTime(ts.Add(-90 * time.Millisecond)))
+}
diff --git a/sdks/go/pkg/beam/runners/prism/prism.go b/sdks/go/pkg/beam/runners/prism/prism.go
new file mode 100644
index 00000000000..dc78e5e6c23
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/prism.go
@@ -0,0 +1,48 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package prism contains a local runner for running
+// pipelines in the current process. Useful for testing.
+package prism
+
+import (
+ "context"
+
+ "github.com/apache/beam/sdks/v2/go/pkg/beam"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/options/jobopts"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal"
+)
+
+func init() {
+ beam.RegisterRunner("prism", Execute)
+ beam.RegisterRunner("PrismRunner", Execute)
+}
+
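+// Execute runs the given pipeline on prism in loopback mode, starting an
+// in-process job server if an endpoint isn't already configured.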
+func Execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error) {
+ if *jobopts.Endpoint == "" {
+ // One hasn't been selected, so let's start one up and set the address.
+ // Conveniently, this means that if multiple pipelines are executed against
+ // the local runner, they will all use the same server.
+ s := jobservices.NewServer(0, internal.RunPipeline)
+ *jobopts.Endpoint = s.Endpoint()
+ go s.Serve()
+ }
+ if !jobopts.IsLoopback() {
+ *jobopts.EnvironmentType = "loopback"
+ }
+ return universal.Execute(ctx, p)
+}
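+
+// A hedged usage sketch (assuming a pipeline p already built with the Go
+// SDK; the blank import registers the "prism" runner):
+//
+//	import _ "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism"
+//
+//	func run(ctx context.Context, p *beam.Pipeline) error {
+//		_, err := beam.Run(ctx, "prism", p)
+//		return err
+//	}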