You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by da...@apache.org on 2022/08/23 19:38:49 UTC
[beam] branch master updated: Add bag state support (#22816)
This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 702e7373c6f Add bag state support (#22816)
702e7373c6f is described below
commit 702e7373c6f90d850e1ee032c732d77bf12c292e
Author: Danny McCormick <da...@google.com>
AuthorDate: Tue Aug 23 15:38:41 2022 -0400
Add bag state support (#22816)
* Add bag state support
* Fix comment
---
sdks/go/pkg/beam/core/graph/fn.go | 9 +-
sdks/go/pkg/beam/core/graph/fn_test.go | 11 +-
sdks/go/pkg/beam/core/runtime/exec/translate.go | 7 +-
sdks/go/pkg/beam/core/runtime/exec/userstate.go | 59 +++++++++
.../pkg/beam/core/runtime/exec/userstate_test.go | 1 +
sdks/go/pkg/beam/core/runtime/graphx/translate.go | 38 ++++--
sdks/go/pkg/beam/core/state/state.go | 76 +++++++++++
sdks/go/pkg/beam/core/state/state_test.go | 147 ++++++++++++++++++++-
8 files changed, 324 insertions(+), 24 deletions(-)
diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go
index 8ac30dc6ec7..78aab66de81 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -1274,11 +1274,10 @@ func validateState(fn *DoFn, numIn mainInputs) error {
"unique per DoFn", k, orig, s)
}
t := s.StateType()
- // TODO(#22736) - Add more state types as they become supported
- if t != state.StateTypeValue {
- err := errors.Errorf("Non-value state type %v for state %v", t, s)
- return errors.SetTopLevelMsgf(err, "Non-value state type %v for state %v. Currently the only supported state"+
- "type is state.Value", t, s)
+ if t != state.StateTypeValue && t != state.StateTypeBag {
+ err := errors.Errorf("Unrecognized state type %v for state %v", t, s)
+ return errors.SetTopLevelMsgf(err, "Unrecognized state type %v for state %v. Currently the only supported state"+
+ "type is state.Value and state.Bag", t, s)
}
stateKeys[k] = s
}
diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go
index f672fa976b0..c2727298b0f 100644
--- a/sdks/go/pkg/beam/core/graph/fn_test.go
+++ b/sdks/go/pkg/beam/core/graph/fn_test.go
@@ -53,7 +53,8 @@ func TestNewDoFn(t *testing.T) {
{dfn: &GoodDoFnCoGbk2{}, opt: CoGBKMainInput(3)},
{dfn: &GoodDoFnCoGbk7{}, opt: CoGBKMainInput(8)},
{dfn: &GoodDoFnCoGbk1wSide{}, opt: NumMainInputs(MainKv)},
- {dfn: &GoodStatefulDoFn{State1: state.Value[int](state.MakeValueState[int]("state1"))}, opt: NumMainInputs(MainKv)},
+ {dfn: &GoodStatefulDoFn{State1: state.MakeValueState[int]("state1")}, opt: NumMainInputs(MainKv)},
+ {dfn: &GoodStatefulDoFn2{State1: state.MakeBagState[int]("state1")}, opt: NumMainInputs(MainKv)},
}
for _, test := range tests {
@@ -1087,6 +1088,14 @@ func (fn *GoodStatefulDoFn) ProcessElement(state.Provider, int, int) int {
return 0
}
+type GoodStatefulDoFn2 struct {
+ State1 state.Bag[int]
+}
+
+func (fn *GoodStatefulDoFn2) ProcessElement(state.Provider, int, int) int {
+ return 0
+}
+
// Examples of incorrect SDF signatures.
// Examples with missing methods.
diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go b/sdks/go/pkg/beam/core/runtime/exec/translate.go
index dc9965a4061..65b0de9d48d 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/translate.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go
@@ -469,7 +469,12 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
stateIDToCoder := make(map[string]*coder.Coder)
for key, spec := range userState {
// TODO(#22736) - this will eventually need to be aware of which type of state its modifying to support non-Value state types.
- cID := spec.GetReadModifyWriteSpec().CoderId
+ var cID string
+ if rmw := spec.GetReadModifyWriteSpec(); rmw != nil {
+ cID = rmw.CoderId
+ } else if bs := spec.GetBagSpec(); bs != nil {
+ cID = bs.ElementCoderId
+ }
c, err := b.coders.Coder(cID)
if err != nil {
return nil, err
diff --git a/sdks/go/pkg/beam/core/runtime/exec/userstate.go b/sdks/go/pkg/beam/core/runtime/exec/userstate.go
index 82d54d92496..21fab583b8e 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/userstate.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/userstate.go
@@ -34,6 +34,7 @@ type stateProvider struct {
transactionsByKey map[string][]state.Transaction
initialValueByKey map[string]interface{}
+ initialBagByKey map[string][]interface{}
readersByKey map[string]io.ReadCloser
appendersByKey map[string]io.Writer
clearersByKey map[string]io.Writer
@@ -57,6 +58,7 @@ func (s *stateProvider) ReadValueState(userStateID string) (interface{}, []state
return nil, []state.Transaction{}, nil
}
initialValue = resp.Elm
+ s.initialValueByKey[userStateID] = initialValue
}
transactions, ok := s.transactionsByKey[userStateID]
@@ -101,6 +103,62 @@ func (s *stateProvider) WriteValueState(val state.Transaction) error {
return nil
}
+// ReadBagState returns the initial elements of a bag state (read from the State API once and then cached per state ID) together with the transactions buffered for it.
+func (s *stateProvider) ReadBagState(userStateID string) ([]interface{}, []state.Transaction, error) {
+ initialValue, ok := s.initialBagByKey[userStateID]
+ if !ok {
+ initialValue = []interface{}{}
+ rw, err := s.getReader(userStateID)
+ if err != nil {
+ return nil, nil, err
+ }
+ dec := MakeElementDecoder(coder.SkipW(s.codersByKey[userStateID]))
+ for err == nil { // keep decoding elements until Decode reports an error (io.EOF on a fully consumed stream)
+ var resp *FullValue
+ resp, err = dec.Decode(rw)
+ if err == nil {
+ initialValue = append(initialValue, resp.Elm)
+ } else if err != io.EOF { // io.EOF only terminates the loop; any other error is returned to the caller
+ return nil, nil, err
+ }
+ }
+ s.initialBagByKey[userStateID] = initialValue // cache so later reads of this state ID do not go back to the reader
+ }
+
+ transactions, ok := s.transactionsByKey[userStateID]
+ if !ok {
+ transactions = []state.Transaction{}
+ }
+
+ return initialValue, transactions, nil
+}
+
+// WriteBagState writes a single element to a bag state via the State API.
+// The encoded element is written to the state's appender, and the transaction is buffered so that later reads can replay it.
+func (s *stateProvider) WriteBagState(val state.Transaction) error {
+ ap, err := s.getAppender(val.Key)
+ if err != nil {
+ return err
+ }
+ fv := FullValue{Elm: val.Val}
+ // TODO(#22736) - consider caching this as a property of stateProvider
+ enc := MakeElementEncoder(coder.SkipW(s.codersByKey[val.Key]))
+ err = enc.Encode(&fv, ap)
+ if err != nil {
+ return err
+ }
+
+ // TODO(#22736) - optimize this a bit once all state types are added.
+ if transactions, ok := s.transactionsByKey[val.Key]; ok {
+ transactions = append(transactions, val)
+ s.transactionsByKey[val.Key] = transactions
+ } else {
+ s.transactionsByKey[val.Key] = []state.Transaction{val}
+ }
+
+ return nil
+}
+
func (s *stateProvider) getReader(userStateID string) (io.ReadCloser, error) {
if r, ok := s.readersByKey[userStateID]; ok {
return r, nil
@@ -189,6 +247,7 @@ func (s *userStateAdapter) NewStateProvider(ctx context.Context, reader StateRea
window: win,
transactionsByKey: make(map[string][]state.Transaction),
initialValueByKey: make(map[string]interface{}),
+ initialBagByKey: make(map[string][]interface{}),
readersByKey: make(map[string]io.ReadCloser),
appendersByKey: make(map[string]io.Writer),
clearersByKey: make(map[string]io.Writer),
diff --git a/sdks/go/pkg/beam/core/runtime/exec/userstate_test.go b/sdks/go/pkg/beam/core/runtime/exec/userstate_test.go
index f75ec09f9b9..a4091bc7d65 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/userstate_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/userstate_test.go
@@ -82,6 +82,7 @@ func buildStateProvider() stateProvider {
window: []byte{1},
transactionsByKey: make(map[string][]state.Transaction),
initialValueByKey: make(map[string]interface{}),
+ initialBagByKey: make(map[string][]interface{}),
readersByKey: make(map[string]io.ReadCloser),
appendersByKey: make(map[string]io.Writer),
clearersByKey: make(map[string]io.Writer),
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/translate.go b/sdks/go/pkg/beam/core/runtime/graphx/translate.go
index 975fdf4f13d..4d113ec2202 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/translate.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/translate.go
@@ -26,6 +26,7 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window/trigger"
v1pb "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx/v1"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/protox"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
@@ -84,6 +85,9 @@ const (
URNEnvProcess = "beam:env:process:v1"
URNEnvExternal = "beam:env:external:v1"
URNEnvDocker = "beam:env:docker:v1"
+
+ // Userstate Urns.
+ URNBagUserState = "beam:user_state:bag:v1"
)
func goCapabilities() []string {
@@ -466,17 +470,31 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
if err != nil {
return handleErr(err)
}
- stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
- // TODO (#22736) - make spec type and protocol conditional on type of State. Right now, assumes ValueState.
- // See https://github.com/apache/beam/blob/54b0784da7ccba738deff22bd83fbc374ad21d2e/sdks/go/pkg/beam/model/pipeline_v1/beam_runner_api.pb.go#L2635
- Spec: &pipepb.StateSpec_ReadModifyWriteSpec{
- ReadModifyWriteSpec: &pipepb.ReadModifyWriteStateSpec{
- CoderId: coderID,
+ switch ps.StateType() {
+ case state.StateTypeValue:
+ stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
+ Spec: &pipepb.StateSpec_ReadModifyWriteSpec{
+ ReadModifyWriteSpec: &pipepb.ReadModifyWriteStateSpec{
+ CoderId: coderID,
+ },
},
- },
- Protocol: &pipepb.FunctionSpec{
- Urn: "beam:user_state:bag:v1",
- },
+ Protocol: &pipepb.FunctionSpec{
+ Urn: URNBagUserState,
+ },
+ }
+ case state.StateTypeBag:
+ stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
+ Spec: &pipepb.StateSpec_BagSpec{
+ BagSpec: &pipepb.BagStateSpec{
+ ElementCoderId: coderID,
+ },
+ },
+ Protocol: &pipepb.FunctionSpec{
+ Urn: URNBagUserState,
+ },
+ }
+ default:
+ return nil, errors.Errorf("State type %v not recognized for state %v", ps.StateKey(), ps)
}
}
payload.StateSpecs = stateSpecs
diff --git a/sdks/go/pkg/beam/core/state/state.go b/sdks/go/pkg/beam/core/state/state.go
index 7ac9efaff9b..8877260afb8 100644
--- a/sdks/go/pkg/beam/core/state/state.go
+++ b/sdks/go/pkg/beam/core/state/state.go
@@ -31,8 +31,12 @@ const (
TransactionTypeSet TransactionTypeEnum = 0
// TransactionTypeClear is the set transaction type
TransactionTypeClear TransactionTypeEnum = 1
+ // TransactionTypeAppend is the append transaction type
+ TransactionTypeAppend TransactionTypeEnum = 2
// StateTypeValue represents a value state
StateTypeValue StateTypeEnum = 0
+ // StateTypeBag represents a bag state
+ StateTypeBag StateTypeEnum = 1
)
var (
@@ -57,6 +61,8 @@ type Transaction struct {
type Provider interface {
ReadValueState(id string) (interface{}, []Transaction, error)
WriteValueState(val Transaction) error
+ ReadBagState(id string) ([]interface{}, []Transaction, error)
+ WriteBagState(val Transaction) error
}
// PipelineState is an interface representing different kinds of PipelineState (currently just state.Value).
@@ -133,3 +139,73 @@ func MakeValueState[T any](k string) Value[T] {
Key: k,
}
}
+
+// Bag is used to read and write global pipeline state representing a collection of values.
+// Key represents the key used to lookup this state.
+type Bag[T any] struct {
+ Key string
+}
+
+// Add appends val to the bag pipeline state by issuing an append transaction for this state's key through the provider p.
+func (s *Bag[T]) Add(p Provider, val T) error {
+ return p.WriteBagState(Transaction{
+ Key: s.Key,
+ Type: TransactionTypeAppend,
+ Val: val,
+ })
+}
+
+// Read returns the current contents of this bag: the provider's initial elements with all buffered transactions replayed on top.
+// When the resulting bag is empty, returns an empty slice and false; on a provider error, returns a nil slice, false and the error.
+func (s *Bag[T]) Read(p Provider) ([]T, bool, error) {
+ // This replays any writes that have happened to this value since we last read
+ // For more detail, see "State Transactionality" below for buffered transactions
+ initialValue, bufferedTransactions, err := p.ReadBagState(s.Key)
+ if err != nil {
+ var val []T
+ return val, false, err
+ }
+ cur := []T{}
+ for _, v := range initialValue {
+ cur = append(cur, v.(T)) // unchecked assertion: panics if a stored element is not of type T
+ }
+ for _, t := range bufferedTransactions {
+ switch t.Type {
+ case TransactionTypeAppend:
+ cur = append(cur, t.Val.(T)) // unchecked assertion: panics if the buffered value is not of type T
+ case TransactionTypeClear:
+ cur = []T{} // a clear discards everything accumulated so far, including the initial elements
+ }
+ }
+ if len(cur) == 0 {
+ return cur, false, nil
+ }
+ return cur, true, nil
+}
+
+// StateKey returns the key for this pipeline state entry. It panics if the Bag was never given a key.
+func (s Bag[T]) StateKey() string {
+ if s.Key == "" {
+ // TODO(#22736) - infer the state from the member variable name during pipeline construction.
+ panic("Bag state exists on struct but has not been initialized with a key.")
+ }
+ return s.Key
+}
+
+// CoderType returns the reflect.Type of the element type T, which should be used to select a coder for this bag state.
+func (s Bag[T]) CoderType() reflect.Type {
+ var t T // zero value of T, used only to obtain its type
+ return reflect.TypeOf(t) // NOTE(review): reflect.TypeOf yields nil when T is an interface type - confirm callers never instantiate Bag with one
+}
+
+// StateType returns the type of the state (in this case always Bag).
+func (s Bag[T]) StateType() StateTypeEnum {
+ return StateTypeBag
+}
+
+// MakeBagState is a factory function to create a Bag[T] state with the given key k.
+func MakeBagState[T any](k string) Bag[T] {
+ return Bag[T]{
+ Key: k,
+ }
+}
diff --git a/sdks/go/pkg/beam/core/state/state_test.go b/sdks/go/pkg/beam/core/state/state_test.go
index 536fa00a60e..c8924560ad6 100644
--- a/sdks/go/pkg/beam/core/state/state_test.go
+++ b/sdks/go/pkg/beam/core/state/state_test.go
@@ -25,9 +25,10 @@ var (
)
type fakeProvider struct {
- initialState map[string]interface{}
- transactions map[string][]Transaction
- err map[string]error
+ initialState map[string]interface{}
+ initialBagState map[string][]interface{}
+ transactions map[string][]Transaction
+ err map[string]error
}
func (s *fakeProvider) ReadValueState(userStateID string) (interface{}, []Transaction, error) {
@@ -51,6 +52,27 @@ func (s *fakeProvider) WriteValueState(val Transaction) error {
return nil
}
+func (s *fakeProvider) ReadBagState(userStateID string) ([]interface{}, []Transaction, error) {
+ if err, ok := s.err[userStateID]; ok { // an error configured for this key takes precedence over any stored state
+ return nil, nil, err
+ }
+ base := s.initialBagState[userStateID] // nil when the key (or the whole map) is absent
+ trans, ok := s.transactions[userStateID]
+ if !ok {
+ trans = []Transaction{}
+ }
+ return base, trans, nil
+}
+
+func (s *fakeProvider) WriteBagState(val Transaction) error {
+ if transactions, ok := s.transactions[val.Key]; ok { // record the write in memory so a later ReadBagState returns it
+ s.transactions[val.Key] = append(transactions, val)
+ } else {
+ s.transactions[val.Key] = []Transaction{val}
+ }
+ return nil // the fake never fails on write
+}
+
func TestValueRead(t *testing.T) {
is := make(map[string]interface{})
ts := make(map[string][]Transaction)
@@ -128,13 +150,124 @@ func TestValueWrite(t *testing.T) {
}
val, ok, err := vs.Read(&f)
if err != nil {
- t.Errorf("Value.Read() returned error %v when it shouldn't have after writing: %v", err, tt.writes)
+ t.Errorf("Value.Write() returned error %v when it shouldn't have after writing: %v", err, tt.writes)
} else if ok && !tt.ok {
- t.Errorf("Value.Read() returned a value %v when it shouldn't have after writing: %v", val, tt.writes)
+ t.Errorf("Value.Write() returned a value %v when it shouldn't have after writing: %v", val, tt.writes)
} else if !ok && tt.ok {
- t.Errorf("Value.Read() didn't return a value when it should have returned %v after writing: %v", tt.val, tt.writes)
+ t.Errorf("Value.Write() didn't return a value when it should have returned %v after writing: %v", tt.val, tt.writes)
} else if val != tt.val {
- t.Errorf("Value.Read()=%v, want %v after writing: %v", val, tt.val, tt.writes)
+ t.Errorf("Value.Write()=%v, want %v after writing: %v", val, tt.val, tt.writes)
+ }
+ }
+}
+
+func TestBagRead(t *testing.T) {
+ is := make(map[string][]interface{})
+ ts := make(map[string][]Transaction)
+ es := make(map[string]error)
+ is["no_transactions"] = []interface{}{1}
+ ts["no_transactions"] = nil
+ is["basic_append"] = []interface{}{}
+ ts["basic_append"] = []Transaction{{Key: "basic_append", Type: TransactionTypeAppend, Val: 3}}
+ is["multi_append"] = []interface{}{}
+ ts["multi_append"] = []Transaction{{Key: "multi_append", Type: TransactionTypeAppend, Val: 3}, {Key: "multi_append", Type: TransactionTypeAppend, Val: 2}}
+ is["basic_clear"] = []interface{}{1}
+ ts["basic_clear"] = []Transaction{{Key: "basic_clear", Type: TransactionTypeClear, Val: nil}}
+ is["append_then_clear"] = []interface{}{1}
+ ts["append_then_clear"] = []Transaction{{Key: "append_then_clear", Type: TransactionTypeAppend, Val: 3}, {Key: "append_then_clear", Type: TransactionTypeClear, Val: nil}}
+ is["append_then_clear_then_append"] = []interface{}{1}
+ ts["append_then_clear_then_append"] = []Transaction{{Key: "append_then_clear_then_append", Type: TransactionTypeAppend, Val: 3}, {Key: "append_then_clear_then_append", Type: TransactionTypeClear, Val: nil}, {Key: "append_then_clear_then_append", Type: TransactionTypeAppend, Val: 4}}
+ is["err"] = []interface{}{1}
+ ts["err"] = []Transaction{{Key: "err", Type: TransactionTypeAppend, Val: 3}}
+ es["err"] = errFake
+
+ f := fakeProvider{ // one shared fixture; each state key above is a separate scenario
+ initialBagState: is,
+ transactions: ts,
+ err: es,
+ }
+
+ var tests = []struct {
+ vs Bag[int]
+ val []int
+ ok bool
+ err error
+ }{
+ {MakeBagState[int]("no_transactions"), []int{1}, true, nil},
+ {MakeBagState[int]("basic_append"), []int{3}, true, nil},
+ {MakeBagState[int]("multi_append"), []int{3, 2}, true, nil},
+ {MakeBagState[int]("basic_clear"), []int{}, false, nil},
+ {MakeBagState[int]("append_then_clear"), []int{}, false, nil},
+ {MakeBagState[int]("append_then_clear_then_append"), []int{4}, true, nil},
+ {MakeBagState[int]("err"), []int{}, false, errFake},
+ }
+
+ for _, tt := range tests {
+ val, ok, err := tt.vs.Read(&f)
+ if err != nil && tt.err == nil {
+ t.Errorf("Bag.Read() returned error %v for state key %v when it shouldn't have", err, tt.vs.Key)
+ } else if err == nil && tt.err != nil {
+ t.Errorf("Bag.Read() returned no error for state key %v when it should have returned %v", tt.vs.Key, tt.err) // err is nil in this branch, so report the expected error
+ } else if ok && !tt.ok {
+ t.Errorf("Bag.Read() returned a value %v for state key %v when it shouldn't have", val, tt.vs.Key)
+ } else if !ok && tt.ok {
+ t.Errorf("Bag.Read() didn't return a value for state key %v when it should have returned %v", tt.vs.Key, tt.val)
+ } else if len(val) != len(tt.val) {
+ t.Errorf("Bag.Read()=%v, want %v for state key %v", val, tt.val, tt.vs.Key)
+ } else {
+ eq := true // lengths already match, so compare element by element (order matters)
+ for idx, v := range val {
+ if v != tt.val[idx] {
+ eq = false
+ }
+ }
+ if !eq {
+ t.Errorf("Bag.Read()=%v, want %v for state key %v", val, tt.val, tt.vs.Key)
+ }
+ }
+ }
+}
+
+func TestBagWrite(t *testing.T) {
+ var tests = []struct {
+ writes []int
+ val []int
+ ok bool
+ }{
+ {[]int{}, []int{}, false},
+ {[]int{3}, []int{3}, true},
+ {[]int{1, 5}, []int{1, 5}, true},
+ }
+
+ for _, tt := range tests {
+ f := fakeProvider{
+ initialState: make(map[string]interface{}),
+ transactions: make(map[string][]Transaction),
+ err: make(map[string]error),
+ }
+ vs := MakeBagState[int]("vs")
+ for _, val := range tt.writes {
+ vs.Add(&f, val)
+ }
+ val, ok, err := vs.Read(&f)
+ if err != nil {
+ t.Errorf("Bag.Write() returned error %v when it shouldn't have after writing: %v", err, tt.writes)
+ } else if ok && !tt.ok {
+ t.Errorf("Bag.Write() returned a value %v when it shouldn't have after writing: %v", val, tt.writes)
+ } else if !ok && tt.ok {
+ t.Errorf("Bag.Write() didn't return a value when it should have returned %v after writing: %v", tt.val, tt.writes)
+ } else if len(val) != len(tt.val) {
+ t.Errorf("Bag.Write()=%v, want %v after writing: %v", val, tt.val, tt.writes)
+ } else {
+ eq := true
+ for idx, v := range val {
+ if v != tt.val[idx] {
+ eq = false
+ }
+ }
+ if !eq {
+ t.Errorf("Bag.Write()=%v, want %v after writing: %v", val, tt.val, tt.writes)
+ }
}
}
}