You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2022/02/15 04:35:24 UTC

[beam] branch release-2.37.0 updated: [BEAM-13930] Address StateSpec consistency issue between Runner and Fn API. (#16852)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch release-2.37.0
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/release-2.37.0 by this push:
     new 5161a11  [BEAM-13930] Address StateSpec consistency issue between Runner and Fn API. (#16852)
5161a11 is described below

commit 5161a11b0b90d888c01d7f4c29b6a40e854517f0
Author: Lukasz Cwik <lc...@google.com>
AuthorDate: Mon Feb 14 20:34:08 2022 -0800

    [BEAM-13930] Address StateSpec consistency issue between Runner and Fn API. (#16852)
    
    The ability to mix and match runners and SDKs is accomplished through two portability layers:
    1. The Runner API provides an SDK-and-runner-independent definition of a Beam pipeline
    2. The Fn API allows a runner to invoke SDK-specific user-defined functions
    
    Apache Beam pipelines support executing stateful DoFns[1]. To support this execution the Runner API defines multiple user state specifications:
    * ReadModifyWriteStateSpec
    * BagStateSpec
    * OrderedListStateSpec
    * CombiningStateSpec
    * MapStateSpec
    * SetStateSpec
    
    The Fn API[2] defines APIs[3] to get, append, and clear user state; it currently supports the BagUserState and MultimapUserState protocols.
    
    Because there is no clear mapping between the Runner API state specifications and the Fn API state protocols, a runner has no way to know whether it supports the state API that is needed to execute a given pipeline. The runner would also have to track additional runtime metadata recording which protocol was used for each type of state so that it can correctly manage the state’s lifetime and clean it up once it is eligible for garbage collection.
    
    Please see the doc[4] for further details and a proposal on how to address this shortcoming.
    
    1: https://beam.apache.org/blog/stateful-processing/
    2: https://github.com/apache/beam/blob/3ad05523f4cdf5122fc319276fcb461f768af39d/model/fn-execution/src/main/proto/beam_fn_api.proto#L742
    3: https://s.apache.org/beam-fn-state-api-and-bundle-processing
    4: http://doc/1ELKTuRTV3C5jt_YoBBwPdsPa5eoXCCOSKQ3GPzZrK7Q
---
 .../pipeline/src/main/proto/beam_runner_api.proto  | 29 +++++++++++++++++++
 .../core/construction/ParDoTranslation.java        | 15 ++++++++++
 .../core/construction/ParDoTranslationTest.java    | 33 ++++++++++++++++++----
 sdks/python/apache_beam/portability/common_urns.py |  2 ++
 sdks/python/apache_beam/transforms/userstate.py    | 17 ++++++++---
 .../apache_beam/transforms/userstate_test.py       | 33 ++++++++++++++++++++++
 6 files changed, 119 insertions(+), 10 deletions(-)

diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto
index b459883..105bdbd 100644
--- a/model/pipeline/src/main/proto/beam_runner_api.proto
+++ b/model/pipeline/src/main/proto/beam_runner_api.proto
@@ -457,6 +457,24 @@ message StandardSideInputTypes {
   }
 }
 
+message StandardUserStateTypes {
+  enum Enum {
+    // Represents a user state specification that supports a bag.
+    //
+    // StateRequests performed on this user state must use
+    // StateKey.BagUserState.
+    BAG = 0 [(beam_urn) = "beam:user_state:bag:v1"];
+
+    // Represents a user state specification that supports a multimap.
+    //
+    // StateRequests performed on this user state must use
+    // StateKey.MultimapKeysUserState or StateKey.MultimapUserState.
+    MULTIMAP = 1 [(beam_urn) = "beam:user_state:multimap:v1"];
+
+    // TODO(BEAM-10650): Add protocol to support OrderedListState
+  }
+}
+
 // A PCollection!
 message PCollection {
 
@@ -534,6 +552,7 @@ message ParDoPayload {
 }
 
 message StateSpec {
+  // TODO(BEAM-13930): Deprecate and remove these state specs
   oneof spec {
     ReadModifyWriteStateSpec read_modify_write_spec = 1;
     BagStateSpec bag_spec = 2;
@@ -542,6 +561,16 @@ message StateSpec {
     SetStateSpec set_spec = 5;
     OrderedListStateSpec ordered_list_spec = 6;
   }
+
+  // (Required) URN of the protocol required by this state specification to present
+  // the desired SDK-specific interface to a UDF.
+  //
+  // This protocol defines the SDK harness <-> Runner Harness RPC
+  // interface for accessing and mutating user state.
+  //
+  // See StandardUserStateTypes for an enumeration of all user state types
+  // defined.
+  FunctionSpec protocol = 7;
 }
 
 message ReadModifyWriteStateSpec {
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
index 1c191b1..ffab927 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java
@@ -46,6 +46,7 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
 import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
 import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput;
 import org.apache.beam.model.pipeline.v1.RunnerApi.StandardRequirements;
+import org.apache.beam.model.pipeline.v1.RunnerApi.StandardUserStateTypes;
 import org.apache.beam.runners.core.construction.PTransformTranslation.TransformPayloadTranslator;
 import org.apache.beam.runners.core.construction.PTransformTranslation.TransformTranslator;
 import org.apache.beam.sdk.Pipeline;
@@ -122,6 +123,11 @@ public class ParDoTranslation {
   public static final String REQUIRES_ON_WINDOW_EXPIRATION_URN =
       "beam:requirement:pardo:on_window_expiration:v1";
 
+  /** Represents a user state specification that supports a bag. */
+  public static final String BAG_USER_STATE = "beam:user_state:bag:v1";
+  /** Represents a user state specification that supports a multimap. */
+  public static final String MULTIMAP_USER_STATE = "beam:user_state:multimap:v1";
+
   static {
     checkState(
         REQUIRES_STATEFUL_PROCESSING_URN.equals(
@@ -140,6 +146,8 @@ public class ParDoTranslation {
     checkState(
         REQUIRES_ON_WINDOW_EXPIRATION_URN.equals(
             getUrn(StandardRequirements.Enum.REQUIRES_ON_WINDOW_EXPIRATION)));
+    checkState(BAG_USER_STATE.equals(getUrn(StandardUserStateTypes.Enum.BAG)));
+    checkState(MULTIMAP_USER_STATE.equals(getUrn(StandardUserStateTypes.Enum.MULTIMAP)));
   }
 
   /** The URN for an unknown Java {@link DoFn}. */
@@ -571,6 +579,7 @@ public class ParDoTranslation {
                 .setReadModifyWriteSpec(
                     RunnerApi.ReadModifyWriteStateSpec.newBuilder()
                         .setCoderId(registerCoderOrThrow(components, valueCoder)))
+                .setProtocol(FunctionSpec.newBuilder().setUrn(BAG_USER_STATE))
                 .build();
           }
 
@@ -580,6 +589,7 @@ public class ParDoTranslation {
                 .setBagSpec(
                     RunnerApi.BagStateSpec.newBuilder()
                         .setElementCoderId(registerCoderOrThrow(components, elementCoder)))
+                .setProtocol(FunctionSpec.newBuilder().setUrn(BAG_USER_STATE))
                 .build();
           }
 
@@ -589,6 +599,8 @@ public class ParDoTranslation {
                 .setOrderedListSpec(
                     RunnerApi.OrderedListStateSpec.newBuilder()
                         .setElementCoderId(registerCoderOrThrow(components, elementCoder)))
+                // TODO(BEAM-10650): Update with correct protocol once the protocol is defined and
+                // the SDK harness uses it.
                 .build();
           }
 
@@ -600,6 +612,7 @@ public class ParDoTranslation {
                     RunnerApi.CombiningStateSpec.newBuilder()
                         .setAccumulatorCoderId(registerCoderOrThrow(components, accumCoder))
                         .setCombineFn(CombineTranslation.toProto(combineFn, components)))
+                .setProtocol(FunctionSpec.newBuilder().setUrn(BAG_USER_STATE))
                 .build();
           }
 
@@ -610,6 +623,7 @@ public class ParDoTranslation {
                     RunnerApi.MapStateSpec.newBuilder()
                         .setKeyCoderId(registerCoderOrThrow(components, keyCoder))
                         .setValueCoderId(registerCoderOrThrow(components, valueCoder)))
+                .setProtocol(FunctionSpec.newBuilder().setUrn(MULTIMAP_USER_STATE))
                 .build();
           }
 
@@ -619,6 +633,7 @@ public class ParDoTranslation {
                 .setSetSpec(
                     RunnerApi.SetStateSpec.newBuilder()
                         .setElementCoderId(registerCoderOrThrow(components, elementCoder)))
+                .setProtocol(FunctionSpec.newBuilder().setUrn(MULTIMAP_USER_STATE))
                 .build();
           }
         });
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java
index 837c3aa..73b6c0f 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java
@@ -24,9 +24,11 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
+import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
 import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
 import org.apache.beam.model.pipeline.v1.RunnerApi.SideInput;
 import org.apache.beam.runners.core.construction.CoderTranslation.TranslationContext;
@@ -225,16 +227,33 @@ public class ParDoTranslationTest {
   public static class TestStateAndTimerTranslation {
 
     @Parameters(name = "{index}: {0}")
-    public static Iterable<StateSpec<?>> stateSpecs() {
-      return ImmutableList.of(
-          StateSpecs.value(VarIntCoder.of()),
-          StateSpecs.bag(VarIntCoder.of()),
-          StateSpecs.set(VarIntCoder.of()),
-          StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of()));
+    public static Iterable<Object[]> stateSpecs() {
+      return Arrays.asList(
+          new Object[][] {
+            {
+              StateSpecs.value(VarIntCoder.of()),
+              FunctionSpec.newBuilder().setUrn(ParDoTranslation.BAG_USER_STATE).build()
+            },
+            {
+              StateSpecs.bag(VarIntCoder.of()),
+              FunctionSpec.newBuilder().setUrn(ParDoTranslation.BAG_USER_STATE).build()
+            },
+            {
+              StateSpecs.set(VarIntCoder.of()),
+              FunctionSpec.newBuilder().setUrn(ParDoTranslation.MULTIMAP_USER_STATE).build()
+            },
+            {
+              StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of()),
+              FunctionSpec.newBuilder().setUrn(ParDoTranslation.MULTIMAP_USER_STATE).build()
+            }
+          });
     }
 
     @Parameter public StateSpec<?> stateSpec;
 
+    @Parameter(1)
+    public FunctionSpec protocol;
+
     @Test
     public void testStateSpecToFromProto() throws Exception {
       // Encode
@@ -243,6 +262,8 @@ public class ParDoTranslationTest {
       RunnerApi.StateSpec stateSpecProto =
           ParDoTranslation.translateStateSpec(stateSpec, sdkComponents);
 
+      assertEquals(stateSpecProto.getProtocol(), protocol);
+
       // Decode
       RehydratedComponents rehydratedComponents =
           RehydratedComponents.forComponents(sdkComponents.toComponents());
diff --git a/sdks/python/apache_beam/portability/common_urns.py b/sdks/python/apache_beam/portability/common_urns.py
index 4e23c4f..1dd46cd 100644
--- a/sdks/python/apache_beam/portability/common_urns.py
+++ b/sdks/python/apache_beam/portability/common_urns.py
@@ -29,6 +29,7 @@ from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardPTransf
 from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardRequirements
 from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardResourceHints
 from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardSideInputTypes
+from apache_beam.portability.api.beam_runner_api_pb2_urns import StandardUserStateTypes
 from apache_beam.portability.api.external_transforms_pb2_urns import ExpansionMethods
 from apache_beam.portability.api.metrics_pb2_urns import MonitoringInfo
 from apache_beam.portability.api.metrics_pb2_urns import MonitoringInfoSpecs
@@ -45,6 +46,7 @@ combine_components = StandardPTransforms.CombineComponents
 sdf_components = StandardPTransforms.SplittableParDoComponents
 group_into_batches_components = StandardPTransforms.GroupIntoBatchesComponents
 
+user_state = StandardUserStateTypes.Enum
 side_inputs = StandardSideInputTypes.Enum
 coders = StandardCoders.Enum
 constants = BeamConstants.Constants
diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py
index 1276118..8e088d1 100644
--- a/sdks/python/apache_beam/transforms/userstate.py
+++ b/sdks/python/apache_beam/transforms/userstate.py
@@ -38,6 +38,7 @@ from typing import TypeVar
 
 from apache_beam.coders import Coder
 from apache_beam.coders import coders
+from apache_beam.portability import common_urns
 from apache_beam.portability.api import beam_runner_api_pb2
 from apache_beam.transforms.timeutil import TimeDomain
 
@@ -76,7 +77,9 @@ class ReadModifyWriteStateSpec(StateSpec):
     # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
     return beam_runner_api_pb2.StateSpec(
         read_modify_write_spec=beam_runner_api_pb2.ReadModifyWriteStateSpec(
-            coder_id=context.coders.get_id(self.coder)))
+            coder_id=context.coders.get_id(self.coder)),
+        protocol=beam_runner_api_pb2.FunctionSpec(
+            urn=common_urns.user_state.BAG.urn))
 
 
 class BagStateSpec(StateSpec):
@@ -85,7 +88,9 @@ class BagStateSpec(StateSpec):
     # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
     return beam_runner_api_pb2.StateSpec(
         bag_spec=beam_runner_api_pb2.BagStateSpec(
-            element_coder_id=context.coders.get_id(self.coder)))
+            element_coder_id=context.coders.get_id(self.coder)),
+        protocol=beam_runner_api_pb2.FunctionSpec(
+            urn=common_urns.user_state.BAG.urn))
 
 
 class SetStateSpec(StateSpec):
@@ -94,7 +99,9 @@ class SetStateSpec(StateSpec):
     # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
     return beam_runner_api_pb2.StateSpec(
         set_spec=beam_runner_api_pb2.SetStateSpec(
-            element_coder_id=context.coders.get_id(self.coder)))
+            element_coder_id=context.coders.get_id(self.coder)),
+        protocol=beam_runner_api_pb2.FunctionSpec(
+            urn=common_urns.user_state.BAG.urn))
 
 
 class CombiningValueStateSpec(StateSpec):
@@ -141,7 +148,9 @@ class CombiningValueStateSpec(StateSpec):
     return beam_runner_api_pb2.StateSpec(
         combining_spec=beam_runner_api_pb2.CombiningStateSpec(
             combine_fn=self.combine_fn.to_runner_api(context),
-            accumulator_coder_id=context.coders.get_id(self.coder)))
+            accumulator_coder_id=context.coders.get_id(self.coder)),
+        protocol=beam_runner_api_pb2.FunctionSpec(
+            urn=common_urns.user_state.BAG.urn))
 
 
 # TODO(BEAM-9562): Update Timer to have of() and clear() APIs.
diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py
index fbddd02..24e996e 100644
--- a/sdks/python/apache_beam/transforms/userstate_test.py
+++ b/sdks/python/apache_beam/transforms/userstate_test.py
@@ -30,6 +30,9 @@ from apache_beam.coders import ListCoder
 from apache_beam.coders import StrUtf8Coder
 from apache_beam.coders import VarIntCoder
 from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.portability import common_urns
+from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.runners import pipeline_context
 from apache_beam.runners.common import DoFnSignature
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.test_stream import TestStream
@@ -157,6 +160,36 @@ class InterfaceTest(unittest.TestCase):
     with self.assertRaises(ValueError):
       DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK))
 
+  def test_state_spec_proto_conversion(self):
+    context = pipeline_context.PipelineContext()
+    state = BagStateSpec('statename', VarIntCoder())
+    state_proto = state.to_runner_api(context)
+    self.assertEquals(
+        beam_runner_api_pb2.FunctionSpec(urn=common_urns.user_state.BAG.urn),
+        state_proto.protocol)
+
+    context = pipeline_context.PipelineContext()
+    state = CombiningValueStateSpec(
+        'statename', VarIntCoder(), TopCombineFn(10))
+    state_proto = state.to_runner_api(context)
+    self.assertEquals(
+        beam_runner_api_pb2.FunctionSpec(urn=common_urns.user_state.BAG.urn),
+        state_proto.protocol)
+
+    context = pipeline_context.PipelineContext()
+    state = SetStateSpec('setstatename', VarIntCoder())
+    state_proto = state.to_runner_api(context)
+    self.assertEquals(
+        beam_runner_api_pb2.FunctionSpec(urn=common_urns.user_state.BAG.urn),
+        state_proto.protocol)
+
+    context = pipeline_context.PipelineContext()
+    state = ReadModifyWriteStateSpec('valuestatename', VarIntCoder())
+    state_proto = state.to_runner_api(context)
+    self.assertEquals(
+        beam_runner_api_pb2.FunctionSpec(urn=common_urns.user_state.BAG.urn),
+        state_proto.protocol)
+
   def test_param_construction(self):
     with self.assertRaises(ValueError):
       DoFn.StateParam(TimerSpec('timer', TimeDomain.WATERMARK))