You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@beam.apache.org by da...@apache.org on 2022/08/26 18:50:19 UTC

[beam] branch master updated: Add map state in the Go Sdk (#22897)

This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 0c82583d6ac Add map state in the Go Sdk (#22897)
0c82583d6ac is described below

commit 0c82583d6ac5c60490d60dd5c28419f2ee524103
Author: Danny McCormick <da...@google.com>
AuthorDate: Fri Aug 26 14:50:08 2022 -0400

    Add map state in the Go Sdk (#22897)
    
    * Add map state in the Go Sdk
    
    * Remove unused function for now
    
    * Comment fixes
    
    * Update sdks/go/pkg/beam/core/runtime/graphx/translate.go
    
    Co-authored-by: Ritesh Ghorse <ri...@gmail.com>
    
    Co-authored-by: Ritesh Ghorse <ri...@gmail.com>
---
 sdks/go/pkg/beam/core/graph/fn.go                  |   4 +-
 sdks/go/pkg/beam/core/graph/fn_test.go             |   9 +
 sdks/go/pkg/beam/core/runtime/exec/data.go         |  10 +
 .../pkg/beam/core/runtime/exec/sideinput_test.go   |  25 +++
 sdks/go/pkg/beam/core/runtime/exec/translate.go    |  14 +-
 sdks/go/pkg/beam/core/runtime/exec/userstate.go    | 233 ++++++++++++++++++---
 sdks/go/pkg/beam/core/runtime/graphx/translate.go  |  36 +++-
 sdks/go/pkg/beam/core/runtime/harness/statemgr.go  | 116 ++++++++++
 sdks/go/pkg/beam/core/state/state.go               | 154 +++++++++++++-
 sdks/go/pkg/beam/core/state/state_test.go          | 209 ++++++++++++++++++
 sdks/go/pkg/beam/pardo.go                          |  11 +-
 11 files changed, 778 insertions(+), 43 deletions(-)

diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go
index 837c8a41377..a78ae73db03 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -1274,10 +1274,10 @@ func validateState(fn *DoFn, numIn mainInputs) error {
 					"unique per DoFn", k, orig, s)
 			}
 			t := s.StateType()
-			if t != state.TypeValue && t != state.TypeBag && t != state.TypeCombining {
+			if t != state.TypeValue && t != state.TypeBag && t != state.TypeCombining && t != state.TypeMap {
 				err := errors.Errorf("Unrecognized state type %v for state %v", t, s)
 				return errors.SetTopLevelMsgf(err, "Unrecognized state type %v for state %v. Currently the only supported state"+
-					"types are state.Value, state.Combining, and state.Bag", t, s)
+					"types are state.Value, state.Combining, state.Bag, and state.Map", t, s)
 			}
 			stateKeys[k] = s
 		}
diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go
index 19647d88cbb..0d9a9e744d2 100644
--- a/sdks/go/pkg/beam/core/graph/fn_test.go
+++ b/sdks/go/pkg/beam/core/graph/fn_test.go
@@ -58,6 +58,7 @@ func TestNewDoFn(t *testing.T) {
 			{dfn: &GoodStatefulDoFn3{State1: state.MakeCombiningState[int, int, int]("state1", func(a, b int) int {
 				return a * b
 			})}, opt: NumMainInputs(MainKv)},
+			{dfn: &GoodStatefulDoFn4{State1: state.MakeMapState[string, int]("state1")}, opt: NumMainInputs(MainKv)},
 		}
 
 		for _, test := range tests {
@@ -1107,6 +1108,14 @@ func (fn *GoodStatefulDoFn3) ProcessElement(state.Provider, int, int) int {
 	return 0
 }
 
+type GoodStatefulDoFn4 struct {
+	State1 state.Map[string, int]
+}
+
+func (fn *GoodStatefulDoFn4) ProcessElement(state.Provider, int, int) int {
+	return 0
+}
+
 // Examples of incorrect SDF signatures.
 // Examples with missing methods.
 
diff --git a/sdks/go/pkg/beam/core/runtime/exec/data.go b/sdks/go/pkg/beam/core/runtime/exec/data.go
index b6b4727d742..fdc1e368a52 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/data.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/data.go
@@ -77,6 +77,16 @@ type StateReader interface {
 	OpenBagUserStateAppender(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.Writer, error)
 	// OpenBagUserStateClearer opens a byte stream for clearing user bag state.
 	OpenBagUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.Writer, error)
+	// OpenMultimapUserStateReader opens a byte stream for reading user multimap state.
+	OpenMultimapUserStateReader(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.ReadCloser, error)
+	// OpenMultimapUserStateAppender opens a byte stream for appending user multimap state.
+	OpenMultimapUserStateAppender(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.Writer, error)
+	// OpenMultimapUserStateClearer opens a byte stream for clearing user multimap state by key.
+	OpenMultimapUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.Writer, error)
+	// OpenMultimapKeysUserStateReader opens a byte stream for reading the keys of user multimap state.
+	OpenMultimapKeysUserStateReader(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.ReadCloser, error)
+	// OpenMultimapKeysUserStateClearer opens a byte stream for clearing all keys of user multimap state.
+	OpenMultimapKeysUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.Writer, error)
 	// GetSideInputCache returns the SideInputCache being used at the harness level.
 	GetSideInputCache() SideCache
 }
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sideinput_test.go b/sdks/go/pkg/beam/core/runtime/exec/sideinput_test.go
index 06098fd8ea5..d045c9e36dd 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sideinput_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sideinput_test.go
@@ -148,6 +148,31 @@ func (t *testStateReader) OpenBagUserStateClearer(ctx context.Context, id Stream
 	return nil, nil
 }
 
+// OpenMultimapUserStateReader opens a byte stream for reading user multimap state.
+func (s *testStateReader) OpenMultimapUserStateReader(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.ReadCloser, error) {
+	return nil, nil
+}
+
+// OpenMultimapUserStateAppender opens a byte stream for appending user multimap state.
+func (s *testStateReader) OpenMultimapUserStateAppender(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.Writer, error) {
+	return nil, nil
+}
+
+// OpenMultimapUserStateClearer opens a byte stream for clearing user multimap state by key.
+func (s *testStateReader) OpenMultimapUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.Writer, error) {
+	return nil, nil
+}
+
+// OpenMultimapKeysUserStateReader opens a byte stream for reading the keys of user multimap state.
+func (s *testStateReader) OpenMultimapKeysUserStateReader(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.ReadCloser, error) {
+	return nil, nil
+}
+
+// OpenMultimapKeysUserStateClearer opens a byte stream for clearing all keys of user multimap state.
+func (s *testStateReader) OpenMultimapKeysUserStateClearer(ctx context.Context, id StreamID, userStateID string, key []byte, w []byte) (io.Writer, error) {
+	return nil, nil
+}
+
 func (t *testStateReader) GetSideInputCache() SideCache {
 	return &testSideCache{}
 }
diff --git a/sdks/go/pkg/beam/core/runtime/exec/translate.go b/sdks/go/pkg/beam/core/runtime/exec/translate.go
index fc4844010fa..abae7d2d080 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/translate.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/translate.go
@@ -467,9 +467,11 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
 
 					if len(userState) > 0 {
 						stateIDToCoder := make(map[string]*coder.Coder)
+						stateIDToKeyCoder := make(map[string]*coder.Coder)
 						stateIDToCombineFn := make(map[string]*graph.CombineFn)
 						for key, spec := range userState {
 							var cID string
+							var kcID string
 							if rmw := spec.GetReadModifyWriteSpec(); rmw != nil {
 								cID = rmw.CoderId
 							} else if bs := spec.GetBagSpec(); bs != nil {
@@ -490,11 +492,21 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
 									return nil, err
 								}
 								stateIDToCombineFn[key] = cfn
+							} else if ms := spec.GetMapSpec(); ms != nil {
+								cID = ms.ValueCoderId
+								kcID = ms.KeyCoderId
 							}
 							c, err := b.coders.Coder(cID)
 							if err != nil {
 								return nil, err
 							}
+							if kcID != "" {
+								kc, err := b.coders.Coder(kcID)
+								if err != nil {
+									return nil, err
+								}
+								stateIDToKeyCoder[key] = kc
+							}
 							stateIDToCoder[key] = c
 							sid := StreamID{
 								Port:         Port{URL: b.desc.GetStateApiServiceDescriptor().GetUrl()},
@@ -505,7 +517,7 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
 							if err != nil {
 								return nil, err
 							}
-							n.UState = NewUserStateAdapter(sid, coder.NewW(ec, wc), stateIDToCoder, stateIDToCombineFn)
+							n.UState = NewUserStateAdapter(sid, coder.NewW(ec, wc), stateIDToCoder, stateIDToKeyCoder, stateIDToCombineFn)
 						}
 					}
 
diff --git a/sdks/go/pkg/beam/core/runtime/exec/userstate.go b/sdks/go/pkg/beam/core/runtime/exec/userstate.go
index 2587530c838..a980d2bb3bb 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/userstate.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/userstate.go
@@ -16,6 +16,7 @@
 package exec
 
 import (
+	"bytes"
 	"context"
 	"fmt"
 	"io"
@@ -34,21 +35,24 @@ type stateProvider struct {
 	elementKey []byte
 	window     []byte
 
-	transactionsByKey map[string][]state.Transaction
-	initialValueByKey map[string]interface{}
-	initialBagByKey   map[string][]interface{}
-	readersByKey      map[string]io.ReadCloser
-	appendersByKey    map[string]io.Writer
-	clearersByKey     map[string]io.Writer
-	codersByKey       map[string]*coder.Coder
-	combineFnsByKey   map[string]*graph.CombineFn
+	transactionsByKey     map[string][]state.Transaction
+	initialValueByKey     map[string]interface{}
+	initialBagByKey       map[string][]interface{}
+	initialMapValuesByKey map[string]map[string]interface{}
+	initialMapKeysByKey   map[string][]interface{}
+	readersByKey          map[string]io.ReadCloser
+	appendersByKey        map[string]io.Writer
+	clearersByKey         map[string]io.Writer
+	codersByKey           map[string]*coder.Coder
+	keyCodersByID         map[string]*coder.Coder
+	combineFnsByKey       map[string]*graph.CombineFn
 }
 
 // ReadValueState reads a value state from the State API
 func (s *stateProvider) ReadValueState(userStateID string) (interface{}, []state.Transaction, error) {
 	initialValue, ok := s.initialValueByKey[userStateID]
 	if !ok {
-		rw, err := s.getReader(userStateID)
+		rw, err := s.getBagReader(userStateID)
 		if err != nil {
 			return nil, nil, err
 		}
@@ -75,13 +79,13 @@ func (s *stateProvider) ReadValueState(userStateID string) (interface{}, []state
 // WriteValueState writes a value state to the State API
 // For value states, this is done by clearing a bag state and writing a value to it.
 func (s *stateProvider) WriteValueState(val state.Transaction) error {
-	cl, err := s.getClearer(val.Key)
+	cl, err := s.getBagClearer(val.Key)
 	if err != nil {
 		return err
 	}
 	cl.Write([]byte{})
 
-	ap, err := s.getAppender(val.Key)
+	ap, err := s.getBagAppender(val.Key)
 	if err != nil {
 		return err
 	}
@@ -106,12 +110,12 @@ func (s *stateProvider) WriteValueState(val state.Transaction) error {
 	return nil
 }
 
-// ReadBagState reads a ReadBagState state from the State API
+// ReadBagState reads a bag state from the State API
 func (s *stateProvider) ReadBagState(userStateID string) ([]interface{}, []state.Transaction, error) {
 	initialValue, ok := s.initialBagByKey[userStateID]
 	if !ok {
 		initialValue = []interface{}{}
-		rw, err := s.getReader(userStateID)
+		rw, err := s.getBagReader(userStateID)
 		if err != nil {
 			return nil, nil, err
 		}
@@ -136,10 +140,9 @@ func (s *stateProvider) ReadBagState(userStateID string) ([]interface{}, []state
 	return initialValue, transactions, nil
 }
 
-// WriteValueState writes a value state to the State API
-// For value states, this is done by clearing a bag state and writing a value to it.
+// WriteBagState writes a bag state to the State API
 func (s *stateProvider) WriteBagState(val state.Transaction) error {
-	ap, err := s.getAppender(val.Key)
+	ap, err := s.getBagAppender(val.Key)
 	if err != nil {
 		return err
 	}
@@ -162,6 +165,105 @@ func (s *stateProvider) WriteBagState(val state.Transaction) error {
 	return nil
 }
 
+// ReadMapStateValue reads a value from the map state for a given key.
+func (s *stateProvider) ReadMapStateValue(userStateID string, key interface{}) (interface{}, []state.Transaction, error) {
+	_, ok := s.initialMapValuesByKey[userStateID]
+	if !ok {
+		s.initialMapValuesByKey[userStateID] = make(map[string]interface{})
+	}
+	b, err := s.encodeKey(userStateID, key)
+	if err != nil {
+		return nil, nil, err
+	}
+	initialValue, ok := s.initialMapValuesByKey[userStateID][string(b)]
+	if !ok {
+		rw, err := s.getMultiMapReader(userStateID, key)
+		if err != nil {
+			return nil, nil, err
+		}
+		dec := MakeElementDecoder(coder.SkipW(s.codersByKey[userStateID]))
+		resp, err := dec.Decode(rw)
+		if err != nil && err != io.EOF {
+			return nil, nil, err
+		}
+		if resp == nil {
+			return nil, []state.Transaction{}, nil
+		}
+		initialValue = resp.Elm
+		s.initialValueByKey[userStateID] = initialValue
+	}
+
+	transactions, ok := s.transactionsByKey[userStateID]
+	if !ok {
+		transactions = []state.Transaction{}
+	}
+
+	return initialValue, transactions, nil
+}
+
+// ReadMapStateKeys reads all the keys in a map state.
+func (s *stateProvider) ReadMapStateKeys(userStateID string) ([]interface{}, []state.Transaction, error) {
+	initialValue, ok := s.initialMapKeysByKey[userStateID]
+	if !ok {
+		initialValue = []interface{}{}
+		rw, err := s.getMultiMapKeyReader(userStateID)
+		if err != nil {
+			return nil, nil, err
+		}
+		dec := MakeElementDecoder(coder.SkipW(s.keyCodersByID[userStateID]))
+		for err == nil {
+			var resp *FullValue
+			resp, err = dec.Decode(rw)
+			if err == nil {
+				initialValue = append(initialValue, resp.Elm)
+			} else if err != io.EOF {
+				return nil, nil, err
+			}
+		}
+		s.initialMapKeysByKey[userStateID] = initialValue
+	}
+
+	transactions, ok := s.transactionsByKey[userStateID]
+	if !ok {
+		transactions = []state.Transaction{}
+	}
+
+	return initialValue, transactions, nil
+}
+
+// WriteMapState writes a key value pair to the global map state.
+func (s *stateProvider) WriteMapState(val state.Transaction) error {
+	cl, err := s.getMultiMapClearer(val.Key, val.MapKey)
+	if err != nil {
+		return err
+	}
+	cl.Write([]byte{})
+
+	ap, err := s.getMultiMapAppender(val.Key, val.MapKey)
+	if err != nil {
+		return err
+	}
+	fv := FullValue{Elm: val.Val}
+	// TODO(#22736) - consider caching this a proprty of stateProvider
+	enc := MakeElementEncoder(coder.SkipW(s.codersByKey[val.Key]))
+	err = enc.Encode(&fv, ap)
+	if err != nil {
+		return err
+	}
+
+	// TODO(#22736) - optimize this a bit once all state types are added. In the case of sets/clears,
+	// we can remove the transactions. We can also consider combining other transactions on read (or sooner)
+	// so that we don't need to use as much memory/time replaying transactions.
+	if transactions, ok := s.transactionsByKey[val.Key]; ok {
+		transactions = append(transactions, val)
+		s.transactionsByKey[val.Key] = transactions
+	} else {
+		s.transactionsByKey[val.Key] = []state.Transaction{val}
+	}
+
+	return nil
+}
+
 func (s *stateProvider) CreateAccumulatorFn(userStateID string) reflectx.Func {
 	a := s.combineFnsByKey[userStateID]
 	if ca := a.CreateAccumulatorFn(); ca != nil {
@@ -196,7 +298,7 @@ func (s *stateProvider) ExtractOutputFn(userStateID string) reflectx.Func {
 	return nil
 }
 
-func (s *stateProvider) getReader(userStateID string) (io.ReadCloser, error) {
+func (s *stateProvider) getBagReader(userStateID string) (io.ReadCloser, error) {
 	if r, ok := s.readersByKey[userStateID]; ok {
 		return r, nil
 	}
@@ -208,7 +310,7 @@ func (s *stateProvider) getReader(userStateID string) (io.ReadCloser, error) {
 	return s.readersByKey[userStateID], nil
 }
 
-func (s *stateProvider) getAppender(userStateID string) (io.Writer, error) {
+func (s *stateProvider) getBagAppender(userStateID string) (io.Writer, error) {
 	if w, ok := s.appendersByKey[userStateID]; ok {
 		return w, nil
 	}
@@ -220,7 +322,7 @@ func (s *stateProvider) getAppender(userStateID string) (io.Writer, error) {
 	return s.appendersByKey[userStateID], nil
 }
 
-func (s *stateProvider) getClearer(userStateID string) (io.Writer, error) {
+func (s *stateProvider) getBagClearer(userStateID string) (io.Writer, error) {
 	if w, ok := s.clearersByKey[userStateID]; ok {
 		return w, nil
 	}
@@ -232,6 +334,65 @@ func (s *stateProvider) getClearer(userStateID string) (io.Writer, error) {
 	return s.clearersByKey[userStateID], nil
 }
 
+func (s *stateProvider) getMultiMapReader(userStateID string, key interface{}) (io.ReadCloser, error) {
+	ek, err := s.encodeKey(userStateID, key)
+	if err != nil {
+		return nil, err
+	}
+	r, err := s.sr.OpenMultimapUserStateReader(s.ctx, s.SID, userStateID, s.elementKey, s.window, ek)
+	if err != nil {
+		return nil, err
+	}
+	return r, nil
+}
+
+func (s *stateProvider) getMultiMapAppender(userStateID string, key interface{}) (io.Writer, error) {
+	ek, err := s.encodeKey(userStateID, key)
+	if err != nil {
+		return nil, err
+	}
+	w, err := s.sr.OpenMultimapUserStateAppender(s.ctx, s.SID, userStateID, s.elementKey, s.window, ek)
+	if err != nil {
+		return nil, err
+	}
+	return w, nil
+}
+
+func (s *stateProvider) getMultiMapClearer(userStateID string, key interface{}) (io.Writer, error) {
+	ek, err := s.encodeKey(userStateID, key)
+	if err != nil {
+		return nil, err
+	}
+	w, err := s.sr.OpenMultimapUserStateClearer(s.ctx, s.SID, userStateID, s.elementKey, s.window, ek)
+	if err != nil {
+		return nil, err
+	}
+	return w, nil
+}
+
+func (s *stateProvider) getMultiMapKeyReader(userStateID string) (io.ReadCloser, error) {
+	if r, ok := s.readersByKey[userStateID]; ok {
+		return r, nil
+	}
+	r, err := s.sr.OpenMultimapKeysUserStateReader(s.ctx, s.SID, userStateID, s.elementKey, s.window)
+	if err != nil {
+		return nil, err
+	}
+	s.readersByKey[userStateID] = r
+	return s.readersByKey[userStateID], nil
+}
+
+func (s *stateProvider) encodeKey(userStateID string, key interface{}) ([]byte, error) {
+	fv := FullValue{Elm: key}
+	enc := MakeElementEncoder(coder.SkipW(s.keyCodersByID[userStateID]))
+	var b bytes.Buffer
+	err := enc.Encode(&fv, &b)
+	if err != nil {
+		return nil, err
+	}
+	return b.Bytes(), nil
+}
+
 // UserStateAdapter provides a state provider to be used for user state.
 type UserStateAdapter interface {
 	NewStateProvider(ctx context.Context, reader StateReader, w typex.Window, element interface{}) (stateProvider, error)
@@ -242,13 +403,14 @@ type userStateAdapter struct {
 	wc                 WindowEncoder
 	kc                 ElementEncoder
 	stateIDToCoder     map[string]*coder.Coder
+	stateIDToKeyCoder  map[string]*coder.Coder
 	stateIDToCombineFn map[string]*graph.CombineFn
 	c                  *coder.Coder
 }
 
 // NewUserStateAdapter returns a user state adapter for the given StreamID and coder.
 // It expects a W<V> or W<KV<K,V>> coder, because the protocol requires windowing information.
-func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string]*coder.Coder, stateIDToCombineFn map[string]*graph.CombineFn) UserStateAdapter {
+func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string]*coder.Coder, stateIDToKeyCoder map[string]*coder.Coder, stateIDToCombineFn map[string]*graph.CombineFn) UserStateAdapter {
 	if !coder.IsW(c) {
 		panic(fmt.Sprintf("expected WV coder for user state %v: %v", sid, c))
 	}
@@ -258,7 +420,7 @@ func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string
 	if coder.IsKV(coder.SkipW(c)) {
 		kc = MakeElementEncoder(coder.SkipW(c).Components[0])
 	}
-	return &userStateAdapter{sid: sid, wc: wc, kc: kc, c: c, stateIDToCoder: stateIDToCoder, stateIDToCombineFn: stateIDToCombineFn}
+	return &userStateAdapter{sid: sid, wc: wc, kc: kc, c: c, stateIDToCoder: stateIDToCoder, stateIDToKeyCoder: stateIDToKeyCoder, stateIDToCombineFn: stateIDToCombineFn}
 }
 
 // NewStateProvider creates a stateProvider with the ability to talk to the state API.
@@ -276,19 +438,22 @@ func (s *userStateAdapter) NewStateProvider(ctx context.Context, reader StateRea
 		return stateProvider{}, err
 	}
 	sp := stateProvider{
-		ctx:               ctx,
-		sr:                reader,
-		SID:               s.sid,
-		elementKey:        elementKey,
-		window:            win,
-		transactionsByKey: make(map[string][]state.Transaction),
-		initialValueByKey: make(map[string]interface{}),
-		initialBagByKey:   make(map[string][]interface{}),
-		readersByKey:      make(map[string]io.ReadCloser),
-		appendersByKey:    make(map[string]io.Writer),
-		clearersByKey:     make(map[string]io.Writer),
-		combineFnsByKey:   s.stateIDToCombineFn,
-		codersByKey:       s.stateIDToCoder,
+		ctx:                   ctx,
+		sr:                    reader,
+		SID:                   s.sid,
+		elementKey:            elementKey,
+		window:                win,
+		transactionsByKey:     make(map[string][]state.Transaction),
+		initialValueByKey:     make(map[string]interface{}),
+		initialBagByKey:       make(map[string][]interface{}),
+		initialMapValuesByKey: make(map[string]map[string]interface{}),
+		initialMapKeysByKey:   make(map[string][]interface{}),
+		readersByKey:          make(map[string]io.ReadCloser),
+		appendersByKey:        make(map[string]io.Writer),
+		clearersByKey:         make(map[string]io.Writer),
+		combineFnsByKey:       s.stateIDToCombineFn,
+		codersByKey:           s.stateIDToCoder,
+		keyCodersByID:         s.stateIDToKeyCoder,
 	}
 
 	return sp, nil
diff --git a/sdks/go/pkg/beam/core/runtime/graphx/translate.go b/sdks/go/pkg/beam/core/runtime/graphx/translate.go
index 3774bf71a2c..850a6647a79 100644
--- a/sdks/go/pkg/beam/core/runtime/graphx/translate.go
+++ b/sdks/go/pkg/beam/core/runtime/graphx/translate.go
@@ -87,7 +87,8 @@ const (
 	URNEnvDocker   = "beam:env:docker:v1"
 
 	// Userstate Urns.
-	URNBagUserState = "beam:user_state:bag:v1"
+	URNBagUserState      = "beam:user_state:bag:v1"
+	URNMultiMapUserState = "beam:user_state:multimap:v1"
 )
 
 func goCapabilities() []string {
@@ -466,10 +467,19 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
 			m.requirements[URNRequiresStatefulProcessing] = true
 			stateSpecs := make(map[string]*pipepb.StateSpec)
 			for _, ps := range edge.Edge.DoFn.PipelineState() {
-				coderID, err := m.coders.Add(edge.Edge.StateCoders[ps.StateKey()])
+				coderID, err := m.coders.Add(edge.Edge.StateCoders[UserStateCoderId(ps)])
 				if err != nil {
 					return handleErr(err)
 				}
+				keyCoderID := ""
+				if c, ok := edge.Edge.StateCoders[UserStateKeyCoderId(ps)]; ok {
+					keyCoderID, err = m.coders.Add(c)
+					if err != nil {
+						return handleErr(err)
+					}
+				} else if ps.StateType() == state.TypeMap {
+					return nil, errors.Errorf("Map type %v must have a key coder type, none detected", ps)
+				}
 				switch ps.StateType() {
 				case state.TypeValue:
 					stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
@@ -525,6 +535,18 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
 							Urn: URNBagUserState,
 						},
 					}
+				case state.TypeMap:
+					stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
+						Spec: &pipepb.StateSpec_MapSpec{
+							MapSpec: &pipepb.MapStateSpec{
+								KeyCoderId:   keyCoderID,
+								ValueCoderId: coderID,
+							},
+						},
+						Protocol: &pipepb.FunctionSpec{
+							Urn: URNMultiMapUserState,
+						},
+					}
 				default:
 					return nil, errors.Errorf("State type %v not recognized for state %v", ps.StateKey(), ps)
 				}
@@ -1376,3 +1398,13 @@ func UpdateDefaultEnvWorkerType(typeUrn string, pyld []byte, p *pipepb.Pipeline)
 	}
 	return errors.Errorf("unable to find dependency with %q role in environment with ID %q,", URNArtifactGoWorkerRole, defaultEnvId)
 }
+
+// UserStateCoderId returns the coder id of a user state
+func UserStateCoderId(ps state.PipelineState) string {
+	return fmt.Sprintf("val_%v", ps.StateKey())
+}
+
+// UserStateKeyCoderId returns the key coder id of a user state
+func UserStateKeyCoderId(ps state.PipelineState) string {
+	return fmt.Sprintf("key_%v", ps.StateKey())
+}
diff --git a/sdks/go/pkg/beam/core/runtime/harness/statemgr.go b/sdks/go/pkg/beam/core/runtime/harness/statemgr.go
index 88fc87cb8a5..f10f0d92e84 100644
--- a/sdks/go/pkg/beam/core/runtime/harness/statemgr.go
+++ b/sdks/go/pkg/beam/core/runtime/harness/statemgr.go
@@ -105,6 +105,46 @@ func (s *ScopedStateReader) OpenBagUserStateClearer(ctx context.Context, id exec
 	return wr, err
 }
 
+// OpenMultimapUserStateReader opens a byte stream for reading user multimap state.
+func (s *ScopedStateReader) OpenMultimapUserStateReader(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.ReadCloser, error) {
+	rw, err := s.openReader(ctx, id, func(ch *StateChannel) *stateKeyReader {
+		return newMultimapUserStateReader(ch, id, s.instID, userStateID, key, w, mk)
+	})
+	return rw, err
+}
+
+// OpenMultimapUserStateAppender opens a byte stream for appending user multimap state.
+func (s *ScopedStateReader) OpenMultimapUserStateAppender(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.Writer, error) {
+	wr, err := s.openWriter(ctx, id, func(ch *StateChannel) *stateKeyWriter {
+		return newMultimapUserStateWriter(ch, id, s.instID, userStateID, key, w, mk, writeTypeAppend)
+	})
+	return wr, err
+}
+
+// OpenMultimapUserStateClearer opens a byte stream for clearing user multimap state by key.
+func (s *ScopedStateReader) OpenMultimapUserStateClearer(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte, mk []byte) (io.Writer, error) {
+	wr, err := s.openWriter(ctx, id, func(ch *StateChannel) *stateKeyWriter {
+		return newMultimapUserStateWriter(ch, id, s.instID, userStateID, key, w, mk, writeTypeClear)
+	})
+	return wr, err
+}
+
+// OpenMultimapKeysUserStateReader opens a byte stream for reading the keys of user multimap state.
+func (s *ScopedStateReader) OpenMultimapKeysUserStateReader(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte) (io.ReadCloser, error) {
+	rw, err := s.openReader(ctx, id, func(ch *StateChannel) *stateKeyReader {
+		return newMultimapKeysUserStateReader(ch, id, s.instID, userStateID, key, w)
+	})
+	return rw, err
+}
+
+// OpenMultimapKeysUserStateClearer opens a byte stream for clearing all keys of user multimap state.
+func (s *ScopedStateReader) OpenMultimapKeysUserStateClearer(ctx context.Context, id exec.StreamID, userStateID string, key []byte, w []byte) (io.Writer, error) {
+	wr, err := s.openWriter(ctx, id, func(ch *StateChannel) *stateKeyWriter {
+		return newMultimapKeysUserStateWriter(ch, id, s.instID, userStateID, key, w, writeTypeClear)
+	})
+	return wr, err
+}
+
 // GetSideInputCache returns a pointer to the SideInputCache being used by the SDK harness.
 func (s *ScopedStateReader) GetSideInputCache() exec.SideCache {
 	return s.cache
@@ -273,6 +313,82 @@ func newBagUserStateWriter(ch *StateChannel, id exec.StreamID, instID instructio
 	}
 }
 
+func newMultimapUserStateReader(ch *StateChannel, id exec.StreamID, instID instructionID, userStateID string, k []byte, w []byte, mk []byte) *stateKeyReader {
+	key := &fnpb.StateKey{
+		Type: &fnpb.StateKey_MultimapUserState_{
+			MultimapUserState: &fnpb.StateKey_MultimapUserState{
+				TransformId: id.PtransformID,
+				UserStateId: userStateID,
+				Window:      w,
+				Key:         k,
+				MapKey:      mk,
+			},
+		},
+	}
+	return &stateKeyReader{
+		instID: instID,
+		key:    key,
+		ch:     ch,
+	}
+}
+
+func newMultimapUserStateWriter(ch *StateChannel, id exec.StreamID, instID instructionID, userStateID string, k []byte, w []byte, mk []byte, wt writeTypeEnum) *stateKeyWriter {
+	key := &fnpb.StateKey{
+		Type: &fnpb.StateKey_MultimapUserState_{
+			MultimapUserState: &fnpb.StateKey_MultimapUserState{
+				TransformId: id.PtransformID,
+				UserStateId: userStateID,
+				Window:      w,
+				Key:         k,
+				MapKey:      mk,
+			},
+		},
+	}
+	return &stateKeyWriter{
+		instID:    instID,
+		key:       key,
+		ch:        ch,
+		writeType: wt,
+	}
+}
+
+func newMultimapKeysUserStateReader(ch *StateChannel, id exec.StreamID, instID instructionID, userStateID string, k []byte, w []byte) *stateKeyReader {
+	key := &fnpb.StateKey{
+		Type: &fnpb.StateKey_MultimapKeysUserState_{
+			MultimapKeysUserState: &fnpb.StateKey_MultimapKeysUserState{
+				TransformId: id.PtransformID,
+				UserStateId: userStateID,
+				Window:      w,
+				Key:         k,
+			},
+		},
+	}
+	return &stateKeyReader{
+		instID: instID,
+		key:    key,
+		ch:     ch,
+	}
+}
+
+func newMultimapKeysUserStateWriter(ch *StateChannel, id exec.StreamID, instID instructionID, userStateID string, k []byte, w []byte, wt writeTypeEnum) *stateKeyWriter {
+	key := &fnpb.StateKey{
+		Type: &fnpb.StateKey_MultimapKeysUserState_{
+			MultimapKeysUserState: &fnpb.StateKey_MultimapKeysUserState{
+				TransformId: id.PtransformID,
+				UserStateId: userStateID,
+				Window:      w,
+				Key:         k,
+			},
+		},
+	}
+	return &stateKeyWriter{
+		instID:    instID,
+		key:       key,
+		ch:        ch,
+		writeType: wt,
+	}
+}
+
 func (r *stateKeyReader) Read(buf []byte) (int, error) {
 	if r.buf == nil {
 		if r.eof {
diff --git a/sdks/go/pkg/beam/core/state/state.go b/sdks/go/pkg/beam/core/state/state.go
index 8f960cfc739..2e700250c84 100644
--- a/sdks/go/pkg/beam/core/state/state.go
+++ b/sdks/go/pkg/beam/core/state/state.go
@@ -42,6 +42,8 @@ const (
 	TypeBag TypeEnum = 1
 	// TypeCombining represents a combining state
 	TypeCombining TypeEnum = 2
+	// TypeMap represents a map state
+	TypeMap TypeEnum = 3
 )
 
 var (
@@ -54,9 +56,10 @@ var (
 // Transaction is used to represent a pending state transaction. This should not be manipulated directly;
 // it is primarily used for implementations of the Provider interface to talk to the various State objects.
 type Transaction struct {
-	Key  string
-	Type TransactionTypeEnum
-	Val  interface{}
+	Key    string
+	Type   TransactionTypeEnum
+	MapKey interface{}
+	Val    interface{}
 }
 
 // Provider represents the DoFn parameter used to get and manipulate pipeline state
@@ -72,12 +75,16 @@ type Provider interface {
 	AddInputFn(userStateID string) reflectx.Func
 	MergeAccumulatorsFn(userStateID string) reflectx.Func
 	ExtractOutputFn(userStateID string) reflectx.Func
+	ReadMapStateValue(userStateID string, key interface{}) (interface{}, []Transaction, error)
+	ReadMapStateKeys(userStateID string) ([]interface{}, []Transaction, error)
+	WriteMapState(val Transaction) error
 }
 
 // PipelineState is an interface representing different kinds of PipelineState (currently just state.Value).
 // It is primarily meant for Beam packages to use and is probably not useful for most pipeline authors.
 type PipelineState interface {
 	StateKey() string
+	KeyCoderType() reflect.Type
 	CoderType() reflect.Type
 	StateType() TypeEnum
 }
@@ -137,6 +144,11 @@ func (s Value[T]) StateKey() string {
 	return s.Key
 }
 
+// KeyCoderType returns nil since Value types aren't keyed.
+func (s Value[T]) KeyCoderType() reflect.Type {
+	return nil
+}
+
 // CoderType returns the type of the value state which should be used for a coder.
 func (s Value[T]) CoderType() reflect.Type {
 	var t T
@@ -207,6 +219,11 @@ func (s Bag[T]) StateKey() string {
 	return s.Key
 }
 
+// KeyCoderType reports no key type: Bag state holds an unkeyed collection of values.
+func (s Bag[T]) KeyCoderType() reflect.Type {
+	return nil
+}
+
 // CoderType returns the type of the bag state which should be used for a coder.
 func (s Bag[T]) CoderType() reflect.Type {
 	var t T
@@ -343,6 +360,11 @@ func (s Combining[T1, T2, T3]) StateKey() string {
 	return s.Key
 }
 
+// KeyCoderType reports no key type: Combining state holds a single accumulator, not keyed entries.
+func (s Combining[T1, T2, T3]) KeyCoderType() reflect.Type {
+	return nil
+}
+
 // CoderType returns the type of the bag state which should be used for a coder.
 func (s Combining[T1, T2, T3]) CoderType() reflect.Type {
 	var t T1
@@ -369,3 +391,129 @@ func MakeCombiningState[T1, T2, T3 any](k string, combiner interface{}) Combinin
 		combineFn: combiner,
 	}
 }
+
+// Map is used to read and write global pipeline state representing a map
+// from keys of type K to values of type V.
+// Key identifies this state entry (it is not the key of map entries).
+// K must be comparable because buffered writes are matched against map keys with ==.
+type Map[K comparable, V any] struct {
+	Key string
+}
+
+// Put writes val under key in this map state by handing a set transaction,
+// tagged with this state's Key, to the provider.
+func (s *Map[K, V]) Put(p Provider, key K, val V) error {
+	write := Transaction{Key: s.Key, Type: TransactionTypeSet, MapKey: key, Val: val}
+	return p.WriteMapState(write)
+}
+
+// Keys is used to read the keys of this map state.
+// Buffered writes made since the last read are replayed on top of the keys
+// returned by the provider. When the map has no keys, returns an empty list
+// and false. Clearing a single key may reorder the remaining keys.
+func (s *Map[K, V]) Keys(p Provider) ([]K, bool, error) {
+	initialValue, bufferedTransactions, err := p.ReadMapStateKeys(s.Key)
+	if err != nil {
+		return []K{}, false, err
+	}
+	cur := make([]K, 0, len(initialValue))
+	for _, v := range initialValue {
+		cur = append(cur, v.(K))
+	}
+	for _, t := range bufferedTransactions {
+		switch t.Type {
+		case TransactionTypeSet:
+			mk := t.MapKey.(K)
+			seen := false
+			for _, k := range cur {
+				if k == mk {
+					// Already present; no need to keep scanning.
+					seen = true
+					break
+				}
+			}
+			if !seen {
+				cur = append(cur, mk)
+			}
+		case TransactionTypeClear:
+			if t.MapKey == nil {
+				// A clear without a map key removes every entry.
+				cur = []K{}
+				continue
+			}
+			k := t.MapKey.(K)
+			for idx, v := range cur {
+				if v == k {
+					// Swap-remove the cleared key; this does not preserve order.
+					cur[idx] = cur[len(cur)-1]
+					cur = cur[:len(cur)-1]
+					break
+				}
+			}
+		}
+	}
+	return cur, len(cur) > 0, nil
+}
+
+// Get is used to read the value stored under key in this map state.
+// Buffered writes made since the last read are replayed, in order, on top of
+// the value returned by the provider. When no value is present, returns the
+// zero value of V and false.
+func (s *Map[K, V]) Get(p Provider, key K) (V, bool, error) {
+	var zero V
+	val, pending, err := p.ReadMapStateValue(s.Key, key)
+	if err != nil {
+		return zero, false, err
+	}
+	for _, tx := range pending {
+		if tx.Type == TransactionTypeSet {
+			if tx.MapKey.(K) == key {
+				val = tx.Val
+			}
+		} else if tx.Type == TransactionTypeClear {
+			// A nil MapKey clears the whole map, including this key.
+			if tx.MapKey == nil || tx.MapKey.(K) == key {
+				val = nil
+			}
+		}
+	}
+	if val == nil {
+		return zero, false, nil
+	}
+	return val.(V), true, nil
+}
+
+// StateKey returns the key for this pipeline state entry.
+// It panics if the Map was not given a key (MakeMapState sets one).
+func (s Map[K, V]) StateKey() string {
+	if s.Key == "" {
+		// TODO(#22736) - infer the state from the member variable name during pipeline construction.
+		panic("Map state exists on struct but has not been initialized with a key.")
+	}
+	return s.Key
+}
+
+// KeyCoderType returns the map key type K, which should be used to build a coder for map keys.
+// The type is derived from *K rather than a zero K so that interface-typed keys yield their
+// interface type instead of nil (nil is reserved for state types that are not keyed).
+func (s Map[K, V]) KeyCoderType() reflect.Type {
+	return reflect.TypeOf((*K)(nil)).Elem()
+}
+
+// CoderType returns the map value type V, which should be used to build a coder for map values.
+// The type is derived from *V rather than a zero V so that interface-typed values yield their
+// interface type instead of nil.
+func (s Map[K, V]) CoderType() reflect.Type {
+	return reflect.TypeOf((*V)(nil)).Elem()
+}
+
+// StateType identifies this state as map state; it always returns TypeMap.
+func (s Map[K, V]) StateType() TypeEnum {
+	return TypeMap
+}
+
+// MakeMapState is a factory function to create an instance of Map state with the given key.
+func MakeMapState[K comparable, V any](k string) Map[K, V] {
+	return Map[K, V]{
+		Key: k,
+	}
+}
diff --git a/sdks/go/pkg/beam/core/state/state_test.go b/sdks/go/pkg/beam/core/state/state_test.go
index 9ee99d790cf..d95677ce3a0 100644
--- a/sdks/go/pkg/beam/core/state/state_test.go
+++ b/sdks/go/pkg/beam/core/state/state_test.go
@@ -29,6 +29,7 @@ var (
 type fakeProvider struct {
 	initialState      map[string]interface{}
 	initialBagState   map[string][]interface{}
+	initialMapState   map[string]map[string]interface{}
 	transactions      map[string][]Transaction
 	err               map[string]error
 	createAccumForKey map[string]bool
@@ -116,6 +117,46 @@ func (s *fakeProvider) ExtractOutputFn(userStateID string) reflectx.Func {
 	return nil
 }
 
+func (s *fakeProvider) ReadMapStateValue(userStateID string, key interface{}) (interface{}, []Transaction, error) {
+	mapKey := key.(string)
+	if err, failed := s.err[userStateID]; failed {
+		return nil, nil, err
+	}
+	// Distinguish "no entry" (empty slice) from an entry explicitly set to nil.
+	pending, found := s.transactions[userStateID]
+	if !found {
+		pending = []Transaction{}
+	}
+	return s.initialMapState[userStateID][mapKey], pending, nil
+}
+
+func (s *fakeProvider) ReadMapStateKeys(userStateID string) ([]interface{}, []Transaction, error) {
+	if err, failed := s.err[userStateID]; failed {
+		return nil, nil, err
+	}
+	entries := s.initialMapState[userStateID]
+	keys := make([]interface{}, 0, len(entries))
+	for k := range entries {
+		keys = append(keys, k)
+	}
+	pending, found := s.transactions[userStateID]
+	if !found {
+		pending = []Transaction{}
+	}
+	return keys, pending, nil
+}
+
+func (s *fakeProvider) WriteMapState(val Transaction) error {
+	// append works whether or not an entry already exists for this state key.
+	s.transactions[val.Key] = append(s.transactions[val.Key], val)
+	return nil
+}
+
 func TestValueRead(t *testing.T) {
 	is := make(map[string]interface{})
 	ts := make(map[string][]Transaction)
@@ -475,3 +516,171 @@ func TestCombiningAdd(t *testing.T) {
 		}
 	}
 }
+
+func TestMapGet(t *testing.T) {
+	is := make(map[string]interface{})
+	im := make(map[string]map[string]interface{})
+	ts := make(map[string][]Transaction)
+	es := make(map[string]error)
+	ca := make(map[string]bool)
+	eo := make(map[string]bool)
+	ts["no_transactions"] = nil
+	im["basic_set"] = map[string]interface{}{"foo": 2}
+	ts["basic_set"] = []Transaction{{Key: "basic_set", Type: TransactionTypeSet, Val: 3, MapKey: "foo"}, {Key: "basic_set", Type: TransactionTypeSet, Val: 1, MapKey: "bar"}}
+	im["basic_clear"] = map[string]interface{}{"foo": 2, "bar": 1}
+	ts["basic_clear"] = []Transaction{{Key: "basic_clear", Type: TransactionTypeClear, Val: nil, MapKey: "foo"}}
+	im["set_then_clear"] = map[string]interface{}{"foo": 2, "bar": 1}
+	ts["set_then_clear"] = []Transaction{{Key: "set_then_clear", Type: TransactionTypeSet, Val: 3, MapKey: "foo"}, {Key: "set_then_clear", Type: TransactionTypeClear, Val: nil, MapKey: "foo"}}
+	im["err"] = map[string]interface{}{"foo": 2}
+	ts["err"] = []Transaction{{Key: "err", Type: TransactionTypeSet, Val: 3}}
+	es["err"] = errFake
+
+	f := fakeProvider{
+		initialMapState:   im,
+		initialState:      is,
+		transactions:      ts,
+		err:               es,
+		createAccumForKey: ca,
+		extractOutForKey:  eo,
+	}
+
+	var tests = []struct {
+		vs    Map[string, int]
+		foo   int
+		bar   int
+		fooOk bool
+		barOk bool
+		err   error
+	}{
+		{MakeMapState[string, int]("no_transactions"), 0, 0, false, false, nil},
+		{MakeMapState[string, int]("basic_set"), 3, 1, true, true, nil},
+		{MakeMapState[string, int]("basic_clear"), 0, 1, false, true, nil},
+		{MakeMapState[string, int]("set_then_clear"), 0, 1, false, true, nil},
+		{MakeMapState[string, int]("err"), 0, 0, false, false, errFake},
+	}
+
+	for _, tt := range tests {
+		// Each row carries expectations for both map keys; check them the same way.
+		checks := []struct {
+			mapKey string
+			want   int
+			wantOk bool
+		}{
+			{"foo", tt.foo, tt.fooOk},
+			{"bar", tt.bar, tt.barOk},
+		}
+		for _, c := range checks {
+			val, ok, err := tt.vs.Get(&f, c.mapKey)
+			switch {
+			case err != nil && tt.err == nil:
+				t.Errorf("Map.Get() returned error %v for state key %v and map key %v when it shouldn't have", err, tt.vs.Key, c.mapKey)
+			case err == nil && tt.err != nil:
+				t.Errorf("Map.Get() returned no error for state key %v and map key %v when it should have returned %v", tt.vs.Key, c.mapKey, tt.err)
+			case err != tt.err:
+				t.Errorf("Map.Get() returned error %v for state key %v and map key %v, want %v", err, tt.vs.Key, c.mapKey, tt.err)
+			case ok && !c.wantOk:
+				t.Errorf("Map.Get() returned a value %v for state key %v and map key %v when it shouldn't have", val, tt.vs.Key, c.mapKey)
+			case !ok && c.wantOk:
+				t.Errorf("Map.Get() didn't return a value for state key %v and map key %v when it should have returned %v", tt.vs.Key, c.mapKey, c.want)
+			case val != c.want:
+				t.Errorf("Map.Get()=%v, want %v for state key %v and map key %v", val, c.want, tt.vs.Key, c.mapKey)
+			}
+		}
+	}
+}
+
+func TestMapKeys(t *testing.T) {
+	is := make(map[string]interface{})
+	im := make(map[string]map[string]interface{})
+	ts := make(map[string][]Transaction)
+	es := make(map[string]error)
+	ca := make(map[string]bool)
+	eo := make(map[string]bool)
+	ts["no_transactions"] = nil
+	im["basic_set"] = map[string]interface{}{"foo": 2}
+	ts["basic_set"] = []Transaction{{Key: "basic_set", Type: TransactionTypeSet, Val: 3, MapKey: "foo"}, {Key: "basic_set", Type: TransactionTypeSet, Val: 1, MapKey: "bar"}}
+	im["basic_clear"] = map[string]interface{}{"foo": 2, "bar": 1}
+	ts["basic_clear"] = []Transaction{{Key: "basic_clear", Type: TransactionTypeClear, Val: nil, MapKey: "foo"}}
+	im["set_then_clear"] = map[string]interface{}{"foo": 2, "bar": 1}
+	ts["set_then_clear"] = []Transaction{{Key: "set_then_clear", Type: TransactionTypeSet, Val: 3, MapKey: "foo"}, {Key: "set_then_clear", Type: TransactionTypeClear, Val: nil, MapKey: "foo"}}
+	im["err"] = map[string]interface{}{"foo": 2}
+	ts["err"] = []Transaction{{Key: "err", Type: TransactionTypeSet, Val: 3}}
+	es["err"] = errFake
+
+	f := fakeProvider{
+		initialMapState:   im,
+		initialState:      is,
+		transactions:      ts,
+		err:               es,
+		createAccumForKey: ca,
+		extractOutForKey:  eo,
+	}
+
+	var tests = []struct {
+		vs   Map[string, int]
+		keys []string
+		ok   bool
+		err  error
+	}{
+		{MakeMapState[string, int]("no_transactions"), []string{}, false, nil},
+		{MakeMapState[string, int]("basic_set"), []string{"foo", "bar"}, true, nil},
+		{MakeMapState[string, int]("basic_clear"), []string{"bar"}, true, nil},
+		{MakeMapState[string, int]("set_then_clear"), []string{"bar"}, true, nil},
+		{MakeMapState[string, int]("err"), []string{}, false, errFake},
+	}
+
+	for _, tt := range tests {
+		val, ok, err := tt.vs.Keys(&f)
+		switch {
+		case err != nil && tt.err == nil:
+			t.Errorf("Map.Keys() returned error %v for state key %v when it shouldn't have", err, tt.vs.Key)
+		case err == nil && tt.err != nil:
+			t.Errorf("Map.Keys() returned no error for state key %v when it should have returned %v", tt.vs.Key, tt.err)
+		case err != tt.err:
+			t.Errorf("Map.Keys() returned error %v for state key %v, want %v", err, tt.vs.Key, tt.err)
+		case ok && !tt.ok:
+			t.Errorf("Map.Keys() returned a value %v for state key %v when it shouldn't have", val, tt.vs.Key)
+		case !ok && tt.ok:
+			t.Errorf("Map.Keys() didn't return a value for state key %v when it should have returned %v", tt.vs.Key, tt.keys)
+		case len(val) != len(tt.keys):
+			t.Errorf("Map.Keys()=%v, want %v for state key %v", val, tt.keys, tt.vs.Key)
+		default:
+			// Keys does not promise an order, so compare the results as multisets.
+			counts := make(map[string]int, len(tt.keys))
+			for _, k := range tt.keys {
+				counts[k]++
+			}
+			for _, k := range val {
+				counts[k]--
+			}
+			for _, n := range counts {
+				if n != 0 {
+					t.Errorf("Map.Keys()=%v, want %v for state key %v", val, tt.keys, tt.vs.Key)
+					break
+				}
+			}
+		}
+	}
+}
+
+func TestMapPut(t *testing.T) {
+	var tests = []struct {
+		writes []int
+		val    int
+		ok     bool
+	}{
+		{[]int{}, 0, false},
+		{[]int{3}, 3, true},
+		{[]int{1, 5}, 5, true},
+	}
+
+	for _, tt := range tests {
+		f := fakeProvider{
+			initialState: make(map[string]interface{}),
+			transactions: make(map[string][]Transaction),
+			err:          make(map[string]error),
+		}
+		vs := MakeMapState[string, int]("vs")
+		for _, val := range tt.writes {
+			// A failed write would make the read below meaningless, so stop immediately.
+			if err := vs.Put(&f, "foo", val); err != nil {
+				t.Fatalf("Map.Put(\"foo\", %v) returned error %v when it shouldn't have", val, err)
+			}
+		}
+		val, ok, err := vs.Get(&f, "foo")
+		if err != nil {
+			t.Errorf("Map.Get(\"foo\") returned error %v when it shouldn't have after writing: %v", err, tt.writes)
+		} else if ok && !tt.ok {
+			t.Errorf("Map.Get(\"foo\") returned a value %v when it shouldn't have after writing: %v", val, tt.writes)
+		} else if !ok && tt.ok {
+			t.Errorf("Map.Get(\"foo\") didn't return a value when it should have returned %v after writing: %v", tt.val, tt.writes)
+		} else if val != tt.val {
+			t.Errorf("Map.Get(\"foo\")=%v, want %v after writing: %v", val, tt.val, tt.writes)
+		}
+	}
+}
diff --git a/sdks/go/pkg/beam/pardo.go b/sdks/go/pkg/beam/pardo.go
index dcdb3e74d1e..47c4dee0fd7 100644
--- a/sdks/go/pkg/beam/pardo.go
+++ b/sdks/go/pkg/beam/pardo.go
@@ -23,6 +23,7 @@ import (
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/log"
@@ -101,7 +102,15 @@ func TryParDo(s Scope, dofn interface{}, col PCollection, opts ...Option) ([]PCo
 			if err != nil {
 				return nil, addParDoCtx(err, s)
 			}
-			edge.StateCoders[ps.StateKey()] = c
+			edge.StateCoders[graphx.UserStateCoderId(ps)] = c
+			if kct := ps.KeyCoderType(); kct != nil {
+				kT := typex.New(kct)
+				kc, err := inferCoder(kT)
+				if err != nil {
+					return nil, addParDoCtx(err, s)
+				}
+				edge.StateCoders[graphx.UserStateKeyCoderId(ps)] = kc
+			}
 		}
 	}