You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by da...@apache.org on 2022/08/26 18:50:19 UTC
[beam] branch master updated: Add map state in the Go Sdk (#22897)
This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 0c82583d6ac Add map state in the Go Sdk (#22897)
0c82583d6ac is described below
commit 0c82583d6ac5c60490d60dd5c28419f2ee524103
Author: Danny McCormick <da...@google.com>
AuthorDate: Fri Aug 26 14:50:08 2022 -0400
Add map state in the Go Sdk (#22897)
* Add map state in the Go Sdk
* Remove unused function for now
* Comment fixes
* Update sdks/go/pkg/beam/core/runtime/graphx/translate.go
Co-authored-by: Ritesh Ghorse <ri...@gmail.com>
Co-authored-by: Ritesh Ghorse <ri...@gmail.com>
---
sdks/go/pkg/beam/core/graph/fn.go | 4 +-
sdks/go/pkg/beam/core/graph/fn_test.go | 9 +
sdks/go/pkg/beam/core/runtime/exec/data.go | 10 +
.../pkg/beam/core/runtime/exec/sideinput_test.go | 25 +++
sdks/go/pkg/beam/core/runtime/exec/translate.go | 14 +-
sdks/go/pkg/beam/core/runtime/exec/userstate.go | 233 ++++++++++++++++++---
sdks/go/pkg/beam/core/runtime/graphx/translate.go | 36 +++-
sdks/go/pkg/beam/core/runtime/harness/statemgr.go | 116 ++++++++++
sdks/go/pkg/beam/core/state/state.go | 154 +++++++++++++-
sdks/go/pkg/beam/core/state/state_test.go | 209 ++++++++++++++++++
sdks/go/pkg/beam/pardo.go | 11 +-
11 files changed, 778 insertions(+), 43 deletions(-)
diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go
index 837c8a41377..a78ae73db03 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -1274,10 +1274,10 @@ func validateState(fn *DoFn, numIn mainInputs) error {
"unique per DoFn", k, orig, s)
}
t := s.StateType()
- if t != state.TypeValue && t != state.TypeBag && t != state.TypeCombining {
+ if t != state.TypeValue && t != state.TypeBag && t != state.TypeCombining && t != state.TypeMap {
err := errors.Errorf("Unrecognized state type %v for state %v", t, s)
return errors.SetTopLevelMsgf(err, "Unrecognized state type %v for state %v. Currently the only supported state"+
- "types are state.Value, state.Combining, and state.Bag", t, s)
+ "types are state.Value, state.Combining, state.Bag, and state.Map", t, s)
}
stateKeys[k] = s
}
diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go
index 19647d88cbb..0d9a9e744d2 100644
--- a/sdks/go/pkg/beam/core/graph/fn_test.go
+++ b/sdks/go/pkg/beam/core/graph/fn_test.go
@@ -58,6 +58,7 @@ func TestNewDoFn(t *testing.T) {
{dfn: &GoodStatefulDoFn3{State1: state.MakeCombiningState[int, int, int]("state1", func(a, b int) int {
return a * b
})}, opt: NumMainInputs(MainKv)},
+ {dfn: &GoodStatefulDoFn4{State1: state.MakeMapState[string, int]("state1")}, opt: NumMainInputs(MainKv)},
}
for _, test := range tests {
@@ -1107,6 +1108,14 @@ func (fn *GoodStatefulDoFn3) ProcessElement(state.Provider, int, int) int {
return 0
}
+type GoodStatefulDoFn4 struct {
+ State1 state.Map[string, int]
+}
+
+func (fn *GoodStatefulDoFn4) ProcessElement(state.Provider, int, int) int {
+ return 0
+}
+
// Examples of incorrect SDF signatures.
// Examples with missing methods.
diff --git a/sdks/go/pkg/beam/core/runtime/exec/data.go b/sdks/go/pkg/beam/core/runtime/exec/data.go
index b6b4727d742..fdc1e368a52 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/data.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/data.go
@@ -77,6 +77,16 @@ type StateReader interface {
OpenBagUserStateAppender(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.Writer, error)
// OpenBagUserStateClearer opens a byte stream for clearing user bag state.
OpenBagUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.Writer, error)
+ // OpenMultimapUserStateReader opens a byte stream for reading user multimap state.
+ OpenMultimapUserStateReader(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.ReadCloser, error)
+ // OpenMultimapUserStateAppender opens a byte stream for appending user multimap state.
+ OpenMultimapUserStateAppender(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.Writer, error)
+ // OpenMultimapUserStateClearer opens a byte stream for clearing user multimap state by key.
+ OpenMultimapUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.Writer, error)
+ // OpenMultimapKeysUserStateReader opens a byte stream for reading the keys of user multimap state.
+ OpenMultimapKeysUserStateReader(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.ReadCloser, error)
+ // OpenMultimapKeysUserStateClearer opens a byte stream for clearing all keys of user multimap state.
+ OpenMultimapKeysUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.Writer, error)
// GetSideInputCache returns the SideInputCache being used at the harness level.
GetSideInputCache() SideCache
}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sideinput_test.go b/sdks/go/pkg/beam/core/runtime/exec/sideinput_test.go
index 06098fd8ea5..d045c9e36dd 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sideinput_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sideinput_test.go
@@ -148,6 +148,31 @@ func (t *testStateReader) OpenBagUserStateClearer(ctx context.Context, id Stream
return nil, nil
}
+// OpenMultimapUserStateReader opens a byte stream for reading user multimap state.
+func (s *testStateReader) OpenMultimapUserStateReader(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.ReadCloser, error) {
+ return nil, nil
+}
+
+// OpenMultimapUserStateAppender opens a byte stream for appending user multimap state.
+func (s *testStateReader) OpenMultimapUserStateAppender(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.Writer, error) {
+ return nil, nil
+}
+
+// OpenMultimapUserStateClearer opens a byte stream for clearing user multimap state by key.
+func (s *testStateReader) OpenMultimapUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.Writer, error) {
+ return nil, nil
+}
+
+// OpenMultimapKeysUserStateReader opens a byte stream for reading the keys of user multimap state.
+func (s *testStateReader) OpenMultimapKeysUserStateReader(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.ReadCloser, error) {
+ return nil, nil
+}
+
+// OpenMultimapKeysUserStateClearer opens a byte stream for clearing all keys of user multimap state.
+func (s *testStateReader) OpenMultimapKeysUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.Writer, error) {
+ return nil, nil
+}
+
func (t *testStateReader) GetSideInputCache() SideCache {
return &testSideCache{}
}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go b/sdks/go/pkg/beam/core/runtime/exec/translate.go
index fc4844010fa..abae7d2d080 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/translate.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go
@@ -467,9 +467,11 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
if len(userState) > 0 {
stateIDToCoder := make(map[string]*coder.Coder)
+ stateIDToKeyCoder := make(map[string]*coder.Coder)
stateIDToCombineFn := make(map[string]*graph.CombineFn)
for key, spec := range userState {
var cID string
+ var kcID string
if rmw := spec.GetReadModifyWriteSpec(); rmw != nil {
cID = rmw.CoderId
} else if bs := spec.GetBagSpec(); bs != nil {
@@ -490,11 +492,21 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
return nil, err
}
stateIDToCombineFn[key] = cfn
+ } else if ms := spec.GetMapSpec(); ms != nil {
+ cID = ms.ValueCoderId
+ kcID = ms.KeyCoderId
}
c, err := b.coders.Coder(cID)
if err != nil {
return nil, err
}
+ if kcID != "" {
+ kc, err := b.coders.Coder(kcID)
+ if err != nil {
+ return nil, err
+ }
+ stateIDToKeyCoder[key] = kc
+ }
stateIDToCoder[key] = c
sid := StreamID{
Port: Port{URL: b.desc.GetStateApiServiceDescriptor().GetUrl()},
@@ -505,7 +517,7 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
if err != nil {
return nil, err
}
- n.UState = NewUserStateAdapter(sid, coder.NewW(ec, wc), stateIDToCoder, stateIDToCombineFn)
+ n.UState = NewUserStateAdapter(sid, coder.NewW(ec, wc), stateIDToCoder, stateIDToKeyCoder, stateIDToCombineFn)
}
}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/userstate.go b/sdks/go/pkg/beam/core/runtime/exec/userstate.go
index 2587530c838..a980d2bb3bb 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/userstate.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/userstate.go
@@ -16,6 +16,7 @@
package exec
import (
+ "bytes"
"context"
"fmt"
"io"
@@ -34,21 +35,24 @@ type stateProvider struct {
elementKey []byte
window []byte
- transactionsByKey map[string][]state.Transaction
- initialValueByKey map[string]interface{}
- initialBagByKey map[string][]interface{}
- readersByKey map[string]io.ReadCloser
- appendersByKey map[string]io.Writer
- clearersByKey map[string]io.Writer
- codersByKey map[string]*coder.Coder
- combineFnsByKey map[string]*graph.CombineFn
+ transactionsByKey map[string][]state.Transaction
+ initialValueByKey map[string]interface{}
+ initialBagByKey map[string][]interface{}
+ initialMapValuesByKey map[string]map[string]interface{}
+ initialMapKeysByKey map[string][]interface{}
+ readersByKey map[string]io.ReadCloser
+ appendersByKey map[string]io.Writer
+ clearersByKey map[string]io.Writer
+ codersByKey map[string]*coder.Coder
+ keyCodersByID map[string]*coder.Coder
+ combineFnsByKey map[string]*graph.CombineFn
}
// ReadValueState reads a value state from the State API
func (s *stateProvider) ReadValueState(userStateID string) (interface{}, []state.Transaction, error) {
initialValue, ok := s.initialValueByKey[userStateID]
if !ok {
- rw, err := s.getReader(userStateID)
+ rw, err := s.getBagReader(userStateID)
if err != nil {
return nil, nil, err
}
@@ -75,13 +79,13 @@ func (s *stateProvider) ReadValueState(userStateID string) (interface{}, []state
// WriteValueState writes a value state to the State API
// For value states, this is done by clearing a bag state and writing a value to it.
func (s *stateProvider) WriteValueState(val state.Transaction) error {
- cl, err := s.getClearer(val.Key)
+ cl, err := s.getBagClearer(val.Key)
if err != nil {
return err
}
cl.Write([]byte{})
- ap, err := s.getAppender(val.Key)
+ ap, err := s.getBagAppender(val.Key)
if err != nil {
return err
}
@@ -106,12 +110,12 @@ func (s *stateProvider) WriteValueState(val state.Transaction) error {
return nil
}
-// ReadBagState reads a ReadBagState state from the State API
+// ReadBagState reads a bag state from the State API
func (s *stateProvider) ReadBagState(userStateID string) ([]interface{}, []state.Transaction, error) {
initialValue, ok := s.initialBagByKey[userStateID]
if !ok {
initialValue = []interface{}{}
- rw, err := s.getReader(userStateID)
+ rw, err := s.getBagReader(userStateID)
if err != nil {
return nil, nil, err
}
@@ -136,10 +140,9 @@ func (s *stateProvider) ReadBagState(userStateID string) ([]interface{}, []state
return initialValue, transactions, nil
}
-// WriteValueState writes a value state to the State API
-// For value states, this is done by clearing a bag state and writing a value to it.
+// WriteBagState writes a bag state to the State API
func (s *stateProvider) WriteBagState(val state.Transaction) error {
- ap, err := s.getAppender(val.Key)
+ ap, err := s.getBagAppender(val.Key)
if err != nil {
return err
}
@@ -162,6 +165,105 @@ func (s *stateProvider) WriteBagState(val state.Transaction) error {
return nil
}
+// ReadMapStateValue reads a value from the map state for a given key.
+func (s *stateProvider) ReadMapStateValue(userStateID string, key interface{}) (interface{}, []state.Transaction, error) {
+ _, ok := s.initialMapValuesByKey[userStateID]
+ if !ok {
+ s.initialMapValuesByKey[userStateID] = make(map[string]interface{})
+ }
+ b, err := s.encodeKey(userStateID, key)
+ if err != nil {
+ return nil, nil, err
+ }
+ initialValue, ok := s.initialMapValuesByKey[userStateID][string(b)]
+ if !ok {
+ rw, err := s.getMultiMapReader(userStateID, key)
+ if err != nil {
+ return nil, nil, err
+ }
+ dec := MakeElementDecoder(coder.SkipW(s.codersByKey[userStateID]))
+ resp, err := dec.Decode(rw)
+ if err != nil && err != io.EOF {
+ return nil, nil, err
+ }
+ if resp == nil {
+ return nil, []state.Transaction{}, nil
+ }
+ initialValue = resp.Elm
+ s.initialValueByKey[userStateID] = initialValue
+ }
+
+ transactions, ok := s.transactionsByKey[userStateID]
+ if !ok {
+ transactions = []state.Transaction{}
+ }
+
+ return initialValue, transactions, nil
+}
+
+// ReadMapStateKeys reads all the keys in a map state.
+func (s *stateProvider) ReadMapStateKeys(userStateID string) ([]interface{}, []state.Transaction, error) {
+ initialValue, ok := s.initialMapKeysByKey[userStateID]
+ if !ok {
+ initialValue = []interface{}{}
+ rw, err := s.getMultiMapKeyReader(userStateID)
+ if err != nil {
+ return nil, nil, err
+ }
+ dec := MakeElementDecoder(coder.SkipW(s.keyCodersByID[userStateID]))
+ for err == nil {
+ var resp *FullValue
+ resp, err = dec.Decode(rw)
+ if err == nil {
+ initialValue = append(initialValue, resp.Elm)
+ } else if err != io.EOF {
+ return nil, nil, err
+ }
+ }
+ s.initialMapKeysByKey[userStateID] = initialValue
+ }
+
+ transactions, ok := s.transactionsByKey[userStateID]
+ if !ok {
+ transactions = []state.Transaction{}
+ }
+
+ return initialValue, transactions, nil
+}
+
+// WriteMapState writes a key value pair to the global map state.
+func (s *stateProvider) WriteMapState(val state.Transaction) error {
+ cl, err := s.getMultiMapClearer(val.Key, val.MapKey)
+ if err != nil {
+ return err
+ }
+ cl.Write([]byte{})
+
+ ap, err := s.getMultiMapAppender(val.Key, val.MapKey)
+ if err != nil {
+ return err
+ }
+ fv := FullValue{Elm: val.Val}
+ // TODO(#22736) - consider caching this as a property of stateProvider
+ enc := MakeElementEncoder(coder.SkipW(s.codersByKey[val.Key]))
+ err = enc.Encode(&fv, ap)
+ if err != nil {
+ return err
+ }
+
+ // TODO(#22736) - optimize this a bit once all state types are added. In the case of sets/clears,
+ // we can remove the transactions. We can also consider combining other transactions on read (or sooner)
+ // so that we don't need to use as much memory/time replaying transactions.
+ if transactions, ok := s.transactionsByKey[val.Key]; ok {
+ transactions = append(transactions, val)
+ s.transactionsByKey[val.Key] = transactions
+ } else {
+ s.transactionsByKey[val.Key] = []state.Transaction{val}
+ }
+
+ return nil
+}
+
func (s *stateProvider) CreateAccumulatorFn(userStateID string) reflectx.Func {
a := s.combineFnsByKey[userStateID]
if ca := a.CreateAccumulatorFn(); ca != nil {
@@ -196,7 +298,7 @@ func (s *stateProvider) ExtractOutputFn(userStateID string) reflectx.Func {
return nil
}
-func (s *stateProvider) getReader(userStateID string) (io.ReadCloser, error) {
+func (s *stateProvider) getBagReader(userStateID string) (io.ReadCloser, error) {
if r, ok := s.readersByKey[userStateID]; ok {
return r, nil
}
@@ -208,7 +310,7 @@ func (s *stateProvider) getReader(userStateID string) (io.ReadCloser, error) {
return s.readersByKey[userStateID], nil
}
-func (s *stateProvider) getAppender(userStateID string) (io.Writer, error) {
+func (s *stateProvider) getBagAppender(userStateID string) (io.Writer, error) {
if w, ok := s.appendersByKey[userStateID]; ok {
return w, nil
}
@@ -220,7 +322,7 @@ func (s *stateProvider) getAppender(userStateID string) (io.Writer, error) {
return s.appendersByKey[userStateID], nil
}
-func (s *stateProvider) getClearer(userStateID string) (io.Writer, error) {
+func (s *stateProvider) getBagClearer(userStateID string) (io.Writer, error) {
if w, ok := s.clearersByKey[userStateID]; ok {
return w, nil
}
@@ -232,6 +334,65 @@ func (s *stateProvider) getClearer(userStateID string) (io.Writer, error) {
return s.clearersByKey[userStateID], nil
}
+func (s *stateProvider) getMultiMapReader(userStateID string, key interface{}) (io.ReadCloser, error) {
+ ek, err := s.encodeKey(userStateID, key)
+ if err != nil {
+ return nil, err
+ }
+ r, err := s.sr.OpenMultimapUserStateReader(s.ctx, s.SID, userStateID, s.elementKey, s.window, ek)
+ if err != nil {
+ return nil, err
+ }
+ return r, nil
+}
+
+func (s *stateProvider) getMultiMapAppender(userStateID string, key interface{}) (io.Writer, error) {
+ ek, err := s.encodeKey(userStateID, key)
+ if err != nil {
+ return nil, err
+ }
+ w, err := s.sr.OpenMultimapUserStateAppender(s.ctx, s.SID, userStateID, s.elementKey, s.window, ek)
+ if err != nil {
+ return nil, err
+ }
+ return w, nil
+}
+
+func (s *stateProvider) getMultiMapClearer(userStateID string, key interface{}) (io.Writer, error) {
+ ek, err := s.encodeKey(userStateID, key)
+ if err != nil {
+ return nil, err
+ }
+ w, err := s.sr.OpenMultimapUserStateClearer(s.ctx, s.SID, userStateID, s.elementKey, s.window, ek)
+ if err != nil {
+ return nil, err
+ }
+ return w, nil
+}
+
+func (s *stateProvider) getMultiMapKeyReader(userStateID string) (io.ReadCloser, error) {
+ if r, ok := s.readersByKey[userStateID]; ok {
+ return r, nil
+ }
+ r, err := s.sr.OpenMultimapKeysUserStateReader(s.ctx, s.SID, userStateID, s.elementKey, s.window)
+ if err != nil {
+ return nil, err
+ }
+ s.readersByKey[userStateID] = r
+ return s.readersByKey[userStateID], nil
+}
+
+func (s *stateProvider) encodeKey(userStateID string, key interface{}) ([]byte, error) {
+ fv := FullValue{Elm: key}
+ enc := MakeElementEncoder(coder.SkipW(s.keyCodersByID[userStateID]))
+ var b bytes.Buffer
+ err := enc.Encode(&fv, &b)
+ if err != nil {
+ return nil, err
+ }
+ return b.Bytes(), nil
+}
+
// UserStateAdapter provides a state provider to be used for user state.
type UserStateAdapter interface {
NewStateProvider(ctx context.Context, reader StateReader, w typex.Window, element interface{}) (stateProvider, error)
@@ -242,13 +403,14 @@ type userStateAdapter struct {
wc WindowEncoder
kc ElementEncoder
stateIDToCoder map[string]*coder.Coder
+ stateIDToKeyCoder map[string]*coder.Coder
stateIDToCombineFn map[string]*graph.CombineFn
c *coder.Coder
}
// NewUserStateAdapter returns a user state adapter for the given StreamID and coder.
// It expects a W<V> or W<KV<K,V>> coder, because the protocol requires windowing information.
-func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string]*coder.Coder, stateIDToCombineFn map[string]*graph.CombineFn) UserStateAdapter {
+func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string]*coder.Coder, stateIDToKeyCoder map[string]*coder.Coder, stateIDToCombineFn map[string]*graph.CombineFn) UserStateAdapter {
if !coder.IsW(c) {
panic(fmt.Sprintf("expected WV coder for user state %v: %v", sid, c))
}
@@ -258,7 +420,7 @@ func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string
if coder.IsKV(coder.SkipW(c)) {
kc = MakeElementEncoder(coder.SkipW(c).Components[0])
}
- return &userStateAdapter{sid: sid, wc: wc, kc: kc, c: c, stateIDToCoder: stateIDToCoder, stateIDToCombineFn: stateIDToCombineFn}
+ return &userStateAdapter{sid: sid, wc: wc, kc: kc, c: c, stateIDToCoder: stateIDToCoder, stateIDToKeyCoder: stateIDToKeyCoder, stateIDToCombineFn: stateIDToCombineFn}
}
// NewStateProvider creates a stateProvider with the ability to talk to the state API.
@@ -276,19 +438,22 @@ func (s *userStateAdapter) NewStateProvider(ctx context.Context, reader StateRea
return stateProvider{}, err
}
sp := stateProvider{
- ctx: ctx,
- sr: reader,
- SID: s.sid,
- elementKey: elementKey,
- window: win,
- transactionsByKey: make(map[string][]state.Transaction),
- initialValueByKey: make(map[string]interface{}),
- initialBagByKey: make(map[string][]interface{}),
- readersByKey: make(map[string]io.ReadCloser),
- appendersByKey: make(map[string]io.Writer),
- clearersByKey: make(map[string]io.Writer),
- combineFnsByKey: s.stateIDToCombineFn,
- codersByKey: s.stateIDToCoder,
+ ctx: ctx,
+ sr: reader,
+ SID: s.sid,
+ elementKey: elementKey,
+ window: win,
+ transactionsByKey: make(map[string][]state.Transaction),
+ initialValueByKey: make(map[string]interface{}),
+ initialBagByKey: make(map[string][]interface{}),
+ initialMapValuesByKey: make(map[string]map[string]interface{}),
+ initialMapKeysByKey: make(map[string][]interface{}),
+ readersByKey: make(map[string]io.ReadCloser),
+ appendersByKey: make(map[string]io.Writer),
+ clearersByKey: make(map[string]io.Writer),
+ combineFnsByKey: s.stateIDToCombineFn,
+ codersByKey: s.stateIDToCoder,
+ keyCodersByID: s.stateIDToKeyCoder,
}
return sp, nil
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/translate.go b/sdks/go/pkg/beam/core/runtime/graphx/translate.go
index 3774bf71a2c..850a6647a79 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/translate.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/translate.go
@@ -87,7 +87,8 @@ const (
URNEnvDocker = "beam:env:docker:v1"
// Userstate Urns.
- URNBagUserState = "beam:user_state:bag:v1"
+ URNBagUserState = "beam:user_state:bag:v1"
+ URNMultiMapUserState = "beam:user_state:multimap:v1"
)
func goCapabilities() []string {
@@ -466,10 +467,19 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
m.requirements[URNRequiresStatefulProcessing] = true
stateSpecs := make(map[string]*pipepb.StateSpec)
for _, ps := range edge.Edge.DoFn.PipelineState() {
- coderID, err := m.coders.Add(edge.Edge.StateCoders[ps.StateKey()])
+ coderID, err := m.coders.Add(edge.Edge.StateCoders[UserStateCoderId(ps)])
if err != nil {
return handleErr(err)
}
+ keyCoderID := ""
+ if c, ok := edge.Edge.StateCoders[UserStateKeyCoderId(ps)]; ok {
+ keyCoderID, err = m.coders.Add(c)
+ if err != nil {
+ return handleErr(err)
+ }
+ } else if ps.StateType() == state.TypeMap {
+ return nil, errors.Errorf("Map type %v must have a key coder type, none detected", ps)
+ }
switch ps.StateType() {
case state.TypeValue:
stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
@@ -525,6 +535,18 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
Urn: URNBagUserState,
},
}
+ case state.TypeMap:
+ stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
+ Spec: &pipepb.StateSpec_MapSpec{
+ MapSpec: &pipepb.MapStateSpec{
+ KeyCoderId: keyCoderID,
+ ValueCoderId: coderID,
+ },
+ },
+ Protocol: &pipepb.FunctionSpec{
+ Urn: URNMultiMapUserState,
+ },
+ }
default:
return nil, errors.Errorf("State type %v not recognized for state %v", ps.StateKey(), ps)
}
@@ -1376,3 +1398,13 @@ func UpdateDefaultEnvWorkerType(typeUrn string, pyld []byte, p *pipepb.Pipeline)
}
return errors.Errorf("unable to find dependency with %q role in environment with ID %q,", URNArtifactGoWorkerRole, defaultEnvId)
}
+
+// UserStateCoderId returns the coder id of a user state
+func UserStateCoderId(ps state.PipelineState) string {
+ return fmt.Sprintf("val_%v", ps.StateKey())
+}
+
+// UserStateKeyCoderId returns the key coder id of a user state
+func UserStateKeyCoderId(ps state.PipelineState) string {
+ return fmt.Sprintf("key_%v", ps.StateKey())
+}
diff --git a/sdks/go/pkg/beam/core/runtime/harness/statemgr.go b/sdks/go/pkg/beam/core/runtime/harness/statemgr.go
index 88fc87cb8a5..f10f0d92e84 100644
--- a/sdks/go/pkg/beam/core/runtime/harness/statemgr.go
+++ b/sdks/go/pkg/beam/core/runtime/harness/statemgr.go
@@ -105,6 +105,46 @@ func (s *ScopedStateReader) OpenBagUserStateClearer(ctx context.Context, id exec
return wr, err
}
+// OpenMultimapUserStateReader opens a byte stream for reading user multimap state.
+func (s *ScopedStateReader) OpenMultimapUserStateReader(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.ReadCloser, error) {
+ rw, err := s.openReader(ctx, id, func(ch *StateChannel) *stateKeyReader {
+ return newMultimapUserStateReader(ch, id, s.instID, userStateID, key, w, mk)
+ })
+ return rw, err
+}
+
+// OpenMultimapUserStateAppender opens a byte stream for appending user multimap state.
+func (s *ScopedStateReader) OpenMultimapUserStateAppender(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.Writer, error) {
+ wr, err := s.openWriter(ctx, id, func(ch *StateChannel) *stateKeyWriter {
+ return newMultimapUserStateWriter(ch, id, s.instID, userStateID, key, w, mk, writeTypeAppend)
+ })
+ return wr, err
+}
+
+// OpenMultimapUserStateClearer opens a byte stream for clearing user multimap state by key.
+func (s *ScopedStateReader) OpenMultimapUserStateClearer(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.Writer, error) {
+ wr, err := s.openWriter(ctx, id, func(ch *StateChannel) *stateKeyWriter {
+ return newMultimapUserStateWriter(ch, id, s.instID, userStateID, key, w, mk, writeTypeClear)
+ })
+ return wr, err
+}
+
+// OpenMultimapKeysUserStateReader opens a byte stream for reading the keys of user multimap state.
+func (s *ScopedStateReader) OpenMultimapKeysUserStateReader(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte) (io.ReadCloser, error) {
+ rw, err := s.openReader(ctx, id, func(ch *StateChannel) *stateKeyReader {
+ return newMultimapKeysUserStateReader(ch, id, s.instID, userStateID, key, w)
+ })
+ return rw, err
+}
+
+// OpenMultimapKeysUserStateClearer opens a byte stream for clearing all keys of user multimap state.
+func (s *ScopedStateReader) OpenMultimapKeysUserStateClearer(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte) (io.Writer, error) {
+ wr, err := s.openWriter(ctx, id, func(ch *StateChannel) *stateKeyWriter {
+ return newMultimapKeysUserStateWriter(ch, id, s.instID, userStateID, key, w, writeTypeClear)
+ })
+ return wr, err
+}
+
// GetSideInputCache returns a pointer to the SideInputCache being used by the SDK harness.
func (s *ScopedStateReader) GetSideInputCache() exec.SideCache {
return s.cache
@@ -273,6 +313,82 @@ func newBagUserStateWriter(ch *StateChannel, id exec.StreamID, instID instructio
}
}
+func newMultimapUserStateReader(ch *StateChannel, id exec.StreamID, instID instructionID, userStateID string, k []byte, w []byte, mk []byte) *stateKeyReader {
+ key := &fnpb.StateKey{
+ Type: &fnpb.StateKey_MultimapUserState_{
+ MultimapUserState: &fnpb.StateKey_MultimapUserState{
+ TransformId: id.PtransformID,
+ UserStateId: userStateID,
+ Window: w,
+ Key: k,
+ MapKey: mk,
+ },
+ },
+ }
+ return &stateKeyReader{
+ instID: instID,
+ key: key,
+ ch: ch,
+ }
+}
+
+func newMultimapUserStateWriter(ch *StateChannel, id exec.StreamID, instID instructionID, userStateID string, k []byte, w []byte, mk []byte, wt writeTypeEnum) *stateKeyWriter {
+ key := &fnpb.StateKey{
+ Type: &fnpb.StateKey_MultimapUserState_{
+ MultimapUserState: &fnpb.StateKey_MultimapUserState{
+ TransformId: id.PtransformID,
+ UserStateId: userStateID,
+ Window: w,
+ Key: k,
+ MapKey: mk,
+ },
+ },
+ }
+ return &stateKeyWriter{
+ instID: instID,
+ key: key,
+ ch: ch,
+ writeType: wt,
+ }
+}
+
+func newMultimapKeysUserStateReader(ch *StateChannel, id exec.StreamID, instID instructionID, userStateID string, k []byte, w []byte) *stateKeyReader {
+ key := &fnpb.StateKey{
+ Type: &fnpb.StateKey_MultimapKeysUserState_{
+ MultimapKeysUserState: &fnpb.StateKey_MultimapKeysUserState{
+ TransformId: id.PtransformID,
+ UserStateId: userStateID,
+ Window: w,
+ Key: k,
+ },
+ },
+ }
+ return &stateKeyReader{
+ instID: instID,
+ key: key,
+ ch: ch,
+ }
+}
+
+func newMultimapKeysUserStateWriter(ch *StateChannel, id exec.StreamID, instID instructionID, userStateID string, k []byte, w []byte, wt writeTypeEnum) *stateKeyWriter {
+ key := &fnpb.StateKey{
+ Type: &fnpb.StateKey_MultimapKeysUserState_{
+ MultimapKeysUserState: &fnpb.StateKey_MultimapKeysUserState{
+ TransformId: id.PtransformID,
+ UserStateId: userStateID,
+ Window: w,
+ Key: k,
+ },
+ },
+ }
+ return &stateKeyWriter{
+ instID: instID,
+ key: key,
+ ch: ch,
+ writeType: wt,
+ }
+}
+
func (r *stateKeyReader) Read(buf []byte) (int, error) {
if r.buf == nil {
if r.eof {
diff --git a/sdks/go/pkg/beam/core/state/state.go b/sdks/go/pkg/beam/core/state/state.go
index 8f960cfc739..2e700250c84 100644
--- a/sdks/go/pkg/beam/core/state/state.go
+++ b/sdks/go/pkg/beam/core/state/state.go
@@ -42,6 +42,8 @@ const (
TypeBag TypeEnum = 1
// TypeCombining represents a combining state
TypeCombining TypeEnum = 2
+ // TypeMap represents a map state
+ TypeMap TypeEnum = 3
)
var (
@@ -54,9 +56,10 @@ var (
// Transaction is used to represent a pending state transaction. This should not be manipulated directly;
// it is primarily used for implementations of the Provider interface to talk to the various State objects.
type Transaction struct {
- Key string
- Type TransactionTypeEnum
- Val interface{}
+ Key string
+ Type TransactionTypeEnum
+ MapKey interface{}
+ Val interface{}
}
// Provider represents the DoFn parameter used to get and manipulate pipeline state
@@ -72,12 +75,16 @@ type Provider interface {
AddInputFn(userStateID string) reflectx.Func
MergeAccumulatorsFn(userStateID string) reflectx.Func
ExtractOutputFn(userStateID string) reflectx.Func
+ ReadMapStateValue(userStateID string, key interface{}) (interface{}, []Transaction, error)
+ ReadMapStateKeys(userStateID string) ([]interface{}, []Transaction, error)
+ WriteMapState(val Transaction) error
}
// PipelineState is an interface representing different kinds of PipelineState (currently just state.Value).
// It is primarily meant for Beam packages to use and is probably not useful for most pipeline authors.
type PipelineState interface {
StateKey() string
+ KeyCoderType() reflect.Type
CoderType() reflect.Type
StateType() TypeEnum
}
@@ -137,6 +144,11 @@ func (s Value[T]) StateKey() string {
return s.Key
}
+// KeyCoderType returns nil since Value types aren't keyed.
+func (s Value[T]) KeyCoderType() reflect.Type {
+ return nil
+}
+
// CoderType returns the type of the value state which should be used for a coder.
func (s Value[T]) CoderType() reflect.Type {
var t T
@@ -207,6 +219,11 @@ func (s Bag[T]) StateKey() string {
return s.Key
}
+// KeyCoderType returns nil since Bag types aren't keyed.
+func (s Bag[T]) KeyCoderType() reflect.Type {
+ return nil
+}
+
// CoderType returns the type of the bag state which should be used for a coder.
func (s Bag[T]) CoderType() reflect.Type {
var t T
@@ -343,6 +360,11 @@ func (s Combining[T1, T2, T3]) StateKey() string {
return s.Key
}
+// KeyCoderType returns nil since combining state types aren't keyed.
+func (s Combining[T1, T2, T3]) KeyCoderType() reflect.Type {
+ return nil
+}
+
// CoderType returns the type of the bag state which should be used for a coder.
func (s Combining[T1, T2, T3]) CoderType() reflect.Type {
var t T1
@@ -369,3 +391,129 @@ func MakeCombiningState[T1, T2, T3 any](k string, combiner interface{}) Combinin
combineFn: combiner,
}
}
+
+// Map is used to read and write global pipeline state representing a map.
+// Key represents the key used to look up this state (not the key of map entries).
+// K is the type of the map entries' keys and V is the type of their values.
+type Map[K comparable, V any] struct {
+ Key string
+}
+
+// Put is used to write a key/value pair to this instance of global map state.
+// It hands a TransactionTypeSet transaction carrying the map key and value to
+// the Provider's WriteMapState and returns whatever error that call reports.
+func (s *Map[K, V]) Put(p Provider, key K, val V) error {
+ return p.WriteMapState(Transaction{
+ Key: s.Key,
+ Type: TransactionTypeSet,
+ MapKey: key,
+ Val: val,
+ })
+}
+
+// Keys is used to read the keys of this map state.
+// When the map has no keys, returns an empty list and false.
+// If the Provider read fails, returns an empty list, false and the error.
+// Callers should not rely on the order of the returned keys: clearing a key
+// moves the last key in the list into the cleared key's position.
+func (s *Map[K, V]) Keys(p Provider) ([]K, bool, error) {
+ // This replays any writes that have happened to this state since we last read
+ // For more detail, see "State Transactionality" below for buffered transactions
+ initialValue, bufferedTransactions, err := p.ReadMapStateKeys(s.Key)
+ if err != nil {
+ return []K{}, false, err
+ }
+ cur := []K{}
+ for _, v := range initialValue {
+ cur = append(cur, v.(K))
+ }
+ for _, t := range bufferedTransactions {
+ switch t.Type {
+ case TransactionTypeSet:
+ // A set only introduces a key if it is not already in the list.
+ seen := false
+ mk := t.MapKey.(K)
+ for _, k := range cur {
+ if k == mk {
+ seen = true
+ }
+ }
+ if !seen {
+ cur = append(cur, mk)
+ }
+ case TransactionTypeClear:
+ if t.MapKey == nil {
+ // A clear without a map key drops every key.
+ cur = []K{}
+ } else {
+ k := t.MapKey.(K)
+ for idx, v := range cur {
+ if v == k {
+ // Remove this key since it has been cleared: overwrite it with the
+ // last key and shrink the list by one (this does not preserve order).
+ cur[idx] = cur[len(cur)-1]
+ cur = cur[:len(cur)-1]
+ break
+ }
+ }
+ }
+ }
+ }
+ if len(cur) == 0 {
+ return cur, false, nil
+ }
+ return cur, true, nil
+}
+
+// Get is used to read a value given a key.
+// When a value is not found, returns the zero value and false.
+// If the Provider read fails, returns the zero value, false and the error.
+func (s *Map[K, V]) Get(p Provider, key K) (V, bool, error) {
+ // This replays any writes that have happened to this key since we last read
+ // For more detail, see "State Transactionality" below for buffered transactions
+ cur, bufferedTransactions, err := p.ReadMapStateValue(s.Key, key)
+ if err != nil {
+ var val V
+ return val, false, err
+ }
+ for _, t := range bufferedTransactions {
+ switch t.Type {
+ case TransactionTypeSet:
+ // Only sets that target the requested key change the result.
+ if t.MapKey.(K) == key {
+ cur = t.Val
+ }
+ case TransactionTypeClear:
+ // A clear with no map key wipes the whole map; otherwise it only
+ // affects the result when it targets the requested key.
+ if t.MapKey == nil || t.MapKey.(K) == key {
+ cur = nil
+ }
+ }
+ }
+ if cur == nil {
+ var val V
+ return val, false, nil
+ }
+ return cur.(V), true, nil
+}
+
+// StateKey returns the key for this pipeline state entry.
+// It panics when Key is empty, i.e. the Map was never initialized with a key.
+func (s Map[K, V]) StateKey() string {
+ if s.Key == "" {
+ // TODO(#22736) - infer the state from the member variable name during pipeline construction.
+ panic("Map state exists on struct but has not been initialized with a key.")
+ }
+ return s.Key
+}
+
+// KeyCoderType returns the type of this map's keys (K), which should be used for a coder for map keys.
+func (s Map[K, V]) KeyCoderType() reflect.Type {
+ var k K
+ return reflect.TypeOf(k)
+}
+
+// CoderType returns the type of this map's values (V), which should be used for a coder for map values.
+func (s Map[K, V]) CoderType() reflect.Type {
+ var v V
+ return reflect.TypeOf(v)
+}
+
+// StateType returns the type of the state (in this case always TypeMap).
+func (s Map[K, V]) StateType() TypeEnum {
+ return TypeMap
+}
+
+// MakeMapState is a factory function to create an instance of Map state with the given key.
+// K and V are the types of the map entries' keys and values respectively.
+func MakeMapState[K comparable, V any](k string) Map[K, V] {
+ return Map[K, V]{
+ Key: k,
+ }
+}
diff --git a/sdks/go/pkg/beam/core/state/state_test.go b/sdks/go/pkg/beam/core/state/state_test.go
index 9ee99d790cf..d95677ce3a0 100644
--- a/sdks/go/pkg/beam/core/state/state_test.go
+++ b/sdks/go/pkg/beam/core/state/state_test.go
@@ -29,6 +29,7 @@ var (
type fakeProvider struct {
initialState map[string]interface{}
initialBagState map[string][]interface{}
+ initialMapState map[string]map[string]interface{}
transactions map[string][]Transaction
err map[string]error
createAccumForKey map[string]bool
@@ -116,6 +117,46 @@ func (s *fakeProvider) ExtractOutputFn(userStateID string) reflectx.Func {
return nil
}
+// ReadMapStateValue is the fake map value read used by the tests in this file.
+// It returns the error configured for userStateID if there is one; otherwise it
+// returns the initial value stored under key together with the transactions
+// recorded for userStateID. The key must be a string or the type assertion panics.
+func (s *fakeProvider) ReadMapStateValue(userStateID string, key interface{}) (interface{}, []Transaction, error) {
+ keyString := key.(string)
+ if err, ok := s.err[userStateID]; ok {
+ return nil, nil, err
+ }
+ base := s.initialMapState[userStateID][keyString]
+ trans, ok := s.transactions[userStateID]
+ if !ok {
+ trans = []Transaction{}
+ }
+ return base, trans, nil
+}
+
+// ReadMapStateKeys is the fake map key read used by the tests in this file.
+// It returns the error configured for userStateID if there is one; otherwise it
+// returns the keys of the initial map state for userStateID (in Go map iteration
+// order, which is unspecified) together with the transactions recorded for it.
+func (s *fakeProvider) ReadMapStateKeys(userStateID string) ([]interface{}, []Transaction, error) {
+ if err, ok := s.err[userStateID]; ok {
+ return nil, nil, err
+ }
+ base := s.initialMapState[userStateID]
+ keys := make([]interface{}, len(base))
+ i := 0
+ for k := range base {
+ keys[i] = k
+ i++
+ }
+ trans, ok := s.transactions[userStateID]
+ if !ok {
+ trans = []Transaction{}
+ }
+ return keys, trans, nil
+}
+
+// WriteMapState is the fake map write used by the tests in this file.
+// It appends val to the transactions recorded under val.Key and always returns nil.
+func (s *fakeProvider) WriteMapState(val Transaction) error {
+ if transactions, ok := s.transactions[val.Key]; ok {
+ s.transactions[val.Key] = append(transactions, val)
+ } else {
+ s.transactions[val.Key] = []Transaction{val}
+ }
+ return nil
+}
+
func TestValueRead(t *testing.T) {
is := make(map[string]interface{})
ts := make(map[string][]Transaction)
@@ -475,3 +516,171 @@ func TestCombiningAdd(t *testing.T) {
}
}
}
+
+// TestMapGet checks that Map.Get replays buffered set/clear transactions on top
+// of the initial map state for the keys "foo" and "bar", and that provider
+// errors are surfaced.
+func TestMapGet(t *testing.T) {
+ is := make(map[string]interface{})
+ im := make(map[string]map[string]interface{})
+ ts := make(map[string][]Transaction)
+ es := make(map[string]error)
+ ca := make(map[string]bool)
+ eo := make(map[string]bool)
+ ts["no_transactions"] = nil
+ im["basic_set"] = map[string]interface{}{"foo": 2}
+ ts["basic_set"] = []Transaction{{Key: "basic_set", Type: TransactionTypeSet, Val: 3, MapKey: "foo"}, {Key: "basic_set", Type: TransactionTypeSet, Val: 1, MapKey: "bar"}}
+ im["basic_clear"] = map[string]interface{}{"foo": 2, "bar": 1}
+ ts["basic_clear"] = []Transaction{{Key: "basic_clear", Type: TransactionTypeClear, Val: nil, MapKey: "foo"}}
+ im["set_then_clear"] = map[string]interface{}{"foo": 2, "bar": 1}
+ ts["set_then_clear"] = []Transaction{{Key: "set_then_clear", Type: TransactionTypeSet, Val: 3, MapKey: "foo"}, {Key: "set_then_clear", Type: TransactionTypeClear, Val: nil, MapKey: "foo"}}
+ im["err"] = map[string]interface{}{"foo": 2}
+ ts["err"] = []Transaction{{Key: "err", Type: TransactionTypeSet, Val: 3}}
+ es["err"] = errFake
+
+ f := fakeProvider{
+ initialMapState: im,
+ initialState: is,
+ transactions: ts,
+ err: es,
+ createAccumForKey: ca,
+ extractOutForKey: eo,
+ }
+
+ var tests = []struct {
+ vs Map[string, int]
+ foo int
+ bar int
+ fooOk bool
+ barOk bool
+ err error
+ }{
+ {MakeMapState[string, int]("no_transactions"), 0, 0, false, false, nil},
+ {MakeMapState[string, int]("basic_set"), 3, 1, true, true, nil},
+ {MakeMapState[string, int]("basic_clear"), 0, 1, false, true, nil},
+ {MakeMapState[string, int]("set_then_clear"), 0, 1, false, true, nil},
+ {MakeMapState[string, int]("err"), 0, 0, false, false, errFake},
+ }
+
+ for _, tt := range tests {
+ val, ok, err := tt.vs.Get(&f, "foo")
+ if err != nil && tt.err == nil {
+ t.Errorf("Map.Get() returned error %v for state key %v and map key foo when it shouldn't have", err, tt.vs.Key)
+ } else if err == nil && tt.err != nil {
+ // Report the expected error; err is nil in this branch.
+ t.Errorf("Map.Get() returned no error for state key %v and map key foo when it should have returned %v", tt.vs.Key, tt.err)
+ } else if ok && !tt.fooOk {
+ t.Errorf("Map.Get() returned a value %v for state key %v and map key foo when it shouldn't have", val, tt.vs.Key)
+ } else if !ok && tt.fooOk {
+ t.Errorf("Map.Get() didn't return a value for state key %v and map key foo when it should have returned %v", tt.vs.Key, tt.foo)
+ } else if val != tt.foo {
+ t.Errorf("Map.Get()=%v, want %v for state key %v and map key foo", val, tt.foo, tt.vs.Key)
+ }
+ val, ok, err = tt.vs.Get(&f, "bar")
+ if err != nil && tt.err == nil {
+ t.Errorf("Map.Get() returned error %v for state key %v and map key bar when it shouldn't have", err, tt.vs.Key)
+ } else if err == nil && tt.err != nil {
+ // Report the expected error; err is nil in this branch.
+ t.Errorf("Map.Get() returned no error for state key %v and map key bar when it should have returned %v", tt.vs.Key, tt.err)
+ } else if ok && !tt.barOk {
+ t.Errorf("Map.Get() returned a value %v for state key %v and map key bar when it shouldn't have", val, tt.vs.Key)
+ } else if !ok && tt.barOk {
+ t.Errorf("Map.Get() didn't return a value for state key %v and map key bar when it should have returned %v", tt.vs.Key, tt.bar)
+ } else if val != tt.bar {
+ t.Errorf("Map.Get()=%v, want %v for state key %v and map key bar", val, tt.bar, tt.vs.Key)
+ }
+ }
+}
+
+// TestMapKeys checks that Map.Keys replays buffered set/clear transactions on
+// top of the initial map keys, and that provider errors are surfaced.
+func TestMapKeys(t *testing.T) {
+ is := make(map[string]interface{})
+ im := make(map[string]map[string]interface{})
+ ts := make(map[string][]Transaction)
+ es := make(map[string]error)
+ ca := make(map[string]bool)
+ eo := make(map[string]bool)
+ ts["no_transactions"] = nil
+ im["basic_set"] = map[string]interface{}{"foo": 2}
+ ts["basic_set"] = []Transaction{{Key: "basic_set", Type: TransactionTypeSet, Val: 3, MapKey: "foo"}, {Key: "basic_set", Type: TransactionTypeSet, Val: 1, MapKey: "bar"}}
+ im["basic_clear"] = map[string]interface{}{"foo": 2, "bar": 1}
+ ts["basic_clear"] = []Transaction{{Key: "basic_clear", Type: TransactionTypeClear, Val: nil, MapKey: "foo"}}
+ im["set_then_clear"] = map[string]interface{}{"foo": 2, "bar": 1}
+ ts["set_then_clear"] = []Transaction{{Key: "set_then_clear", Type: TransactionTypeSet, Val: 3, MapKey: "foo"}, {Key: "set_then_clear", Type: TransactionTypeClear, Val: nil, MapKey: "foo"}}
+ im["err"] = map[string]interface{}{"foo": 2}
+ ts["err"] = []Transaction{{Key: "err", Type: TransactionTypeSet, Val: 3}}
+ es["err"] = errFake
+
+ f := fakeProvider{
+ initialMapState: im,
+ initialState: is,
+ transactions: ts,
+ err: es,
+ createAccumForKey: ca,
+ extractOutForKey: eo,
+ }
+
+ var tests = []struct {
+ vs Map[string, int]
+ keys []string
+ ok bool
+ err error
+ }{
+ {MakeMapState[string, int]("no_transactions"), []string{}, false, nil},
+ {MakeMapState[string, int]("basic_set"), []string{"foo", "bar"}, true, nil},
+ {MakeMapState[string, int]("basic_clear"), []string{"bar"}, true, nil},
+ {MakeMapState[string, int]("set_then_clear"), []string{"bar"}, true, nil},
+ {MakeMapState[string, int]("err"), []string{}, false, errFake},
+ }
+
+ for _, tt := range tests {
+ val, ok, err := tt.vs.Keys(&f)
+ if err != nil && tt.err == nil {
+ t.Errorf("Map.Keys() returned error %v for state key %v when it shouldn't have", err, tt.vs.Key)
+ } else if err == nil && tt.err != nil {
+ // Report the expected error; err is nil in this branch.
+ t.Errorf("Map.Keys() returned no error for state key %v when it should have returned %v", tt.vs.Key, tt.err)
+ } else if ok && !tt.ok {
+ t.Errorf("Map.Keys() returned a value %v for state key %v when it shouldn't have", val, tt.vs.Key)
+ } else if !ok && tt.ok {
+ t.Errorf("Map.Keys() didn't return a value for state key %v when it should have returned %v", tt.vs.Key, tt.keys)
+ } else if len(val) != len(tt.keys) {
+ t.Errorf("Map.Keys()=%v, want %v for state key %v", val, tt.keys, tt.vs.Key)
+ } else {
+ eq := true
+ for idx, v := range val {
+ if v != tt.keys[idx] {
+ eq = false
+ }
+ }
+ if !eq {
+ t.Errorf("Map.Keys()=%v, want %v for state key %v", val, tt.keys, tt.vs.Key)
+ }
+ }
+ }
+}
+
+// TestMapPut checks that values written with Map.Put are visible to a later
+// Map.Get of the same key, with the last write winning.
+func TestMapPut(t *testing.T) {
+ var tests = []struct {
+ writes []int
+ val int
+ ok bool
+ }{
+ {[]int{}, 0, false},
+ {[]int{3}, 3, true},
+ {[]int{1, 5}, 5, true},
+ }
+
+ for _, tt := range tests {
+ f := fakeProvider{
+ initialState: make(map[string]interface{}),
+ transactions: make(map[string][]Transaction),
+ err: make(map[string]error),
+ }
+ vs := MakeMapState[string, int]("vs")
+ for _, val := range tt.writes {
+ // A failed write would invalidate the read assertions below, so stop here.
+ if err := vs.Put(&f, "foo", val); err != nil {
+ t.Fatalf("Map.Put(\"foo\", %v) returned error %v when it shouldn't have", val, err)
+ }
+ }
+ val, ok, err := vs.Get(&f, "foo")
+ if err != nil {
+ t.Errorf("Map.Get(\"foo\") returned error %v when it shouldn't have after writing: %v", err, tt.writes)
+ } else if ok && !tt.ok {
+ t.Errorf("Map.Get(\"foo\") returned a value %v when it shouldn't have after writing: %v", val, tt.writes)
+ } else if !ok && tt.ok {
+ t.Errorf("Map.Get(\"foo\") didn't return a value when it should have returned %v after writing: %v", tt.val, tt.writes)
+ } else if val != tt.val {
+ t.Errorf("Map.Get(\"foo\")=%v, want %v after writing: %v", val, tt.val, tt.writes)
+ }
+ }
+}
diff --git a/sdks/go/pkg/beam/pardo.go b/sdks/go/pkg/beam/pardo.go
index dcdb3e74d1e..47c4dee0fd7 100644
--- a/sdks/go/pkg/beam/pardo.go
+++ b/sdks/go/pkg/beam/pardo.go
@@ -23,6 +23,7 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+ "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/v2/go/pkg/beam/log"
@@ -101,7 +102,15 @@ func TryParDo(s Scope, dofn interface{}, col PCollection, opts ...Option) ([]PCo
if err != nil {
return nil, addParDoCtx(err, s)
}
- edge.StateCoders[ps.StateKey()] = c
+ edge.StateCoders[graphx.UserStateCoderId(ps)] = c
+ if kct := ps.KeyCoderType(); kct != nil {
+ kT := typex.New(kct)
+ kc, err := inferCoder(kT)
+ if err != nil {
+ return nil, addParDoCtx(err, s)
+ }
+ edge.StateCoders[graphx.UserStateKeyCoderId(ps)] = kc
+ }
}
}