Posted to commits@beam.apache.org by yi...@apache.org on 2022/05/05 20:41:37 UTC

[beam] branch master updated: [BEAM-12603] Add retry on grpc data channel and remove retry from test. (#17537)

This is an automated email from the ASF dual-hosted git repository.

yichi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 9154f8b3cf0 [BEAM-12603] Add retry on grpc data channel and remove retry from test. (#17537)
9154f8b3cf0 is described below

commit 9154f8b3cf01fce067cea7ad7522e1db3154ff89
Author: Yichi Zhang <zy...@google.com>
AuthorDate: Thu May 5 13:41:30 2022 -0700

    [BEAM-12603] Add retry on grpc data channel and remove retry from test. (#17537)
---
 .../portability/fn_api_runner/fn_runner_test.py    | 53 ----------------------
 .../apache_beam/runners/worker/data_plane.py       | 25 +++++++++-
 .../apache_beam/runners/worker/sdk_worker.py       | 20 +++++++-
 3 files changed, 43 insertions(+), 55 deletions(-)
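
Context: the retry added by this commit uses gRPC's channel-level service config, which lets the client library retry failed RPCs transparently instead of each caller (or test) retrying by hand. A minimal standalone sketch of the mechanism follows; the endpoint "localhost:50051" is illustrative and not part of this commit, while the service name and retryPolicy match the config added in the diff below.

    import json

    import grpc

    # Same shape as the _GRPC_SERVICE_CONFIG added in data_plane.py:
    # retry calls to the BeamFnData service on transient UNAVAILABLE,
    # up to 5 attempts with exponential backoff capped at 5s.
    service_config = json.dumps({
        "methodConfig": [{
            "name": [{
                "service": "org.apache.beam.model.fn_execution.v1.BeamFnData"
            }],
            "retryPolicy": {
                "maxAttempts": 5,
                "initialBackoff": "0.1s",
                "maxBackoff": "5s",
                "backoffMultiplier": 2,
                "retryableStatusCodes": ["UNAVAILABLE"],
            },
        }]
    })

    # Passing the config as a channel option makes grpc itself perform
    # the retries; callers only ever see the final outcome.
    channel = grpc.insecure_channel(
        "localhost:50051",
        options=[("grpc.service_config", service_config)])
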

diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
index 06016706d03..3a8415e61b8 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner/fn_runner_test.py
@@ -105,7 +105,6 @@ class FnApiRunnerTest(unittest.TestCase):
   def create_pipeline(self, is_drain=False):
     return beam.Pipeline(runner=fn_api_runner.FnApiRunner(is_drain=is_drain))
 
-  @retry(stop=stop_after_attempt(3))
   def test_assert_that(self):
     # TODO: figure out a way for fn_api_runner to parse and raise the
     # underlying exception.
@@ -113,12 +112,10 @@ class FnApiRunnerTest(unittest.TestCase):
       with self.create_pipeline() as p:
         assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
 
-  @retry(stop=stop_after_attempt(3))
   def test_create(self):
     with self.create_pipeline() as p:
       assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b']))
 
-  @retry(stop=stop_after_attempt(3))
   def test_pardo(self):
     with self.create_pipeline() as p:
       res = (
@@ -296,7 +293,6 @@ class FnApiRunnerTest(unittest.TestCase):
                                  9*9                        # [ 9, 14)
                                  ]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_pardo_side_outputs(self):
     def tee(elem, *tags):
       for tag in tags:
@@ -311,7 +307,6 @@ class FnApiRunnerTest(unittest.TestCase):
       assert_that(xy.x, equal_to(['x', 'xy']), label='x')
       assert_that(xy.y, equal_to(['y', 'xy']), label='y')
 
-  @retry(stop=stop_after_attempt(3))
   def test_pardo_side_and_main_outputs(self):
     def even_odd(elem):
       yield elem
@@ -331,7 +326,6 @@ class FnApiRunnerTest(unittest.TestCase):
       assert_that(unnamed.even, equal_to([2]), label='unnamed.even')
       assert_that(unnamed.odd, equal_to([1, 3]), label='unnamed.odd')
 
-  @retry(stop=stop_after_attempt(3))
   def test_pardo_side_inputs(self):
     def cross_product(elem, sides):
       for side in sides:
@@ -371,7 +365,6 @@ class FnApiRunnerTest(unittest.TestCase):
                 *[beam.pvalue.AsList(inputs[s]) for s in range(1, k)]))
 
   @unittest.skip('BEAM-13040')
-  @retry(stop=stop_after_attempt(3))
   def test_pardo_side_input_sparse_dependencies(self):
     with self.create_pipeline() as p:
       inputs = []
@@ -392,7 +385,6 @@ class FnApiRunnerTest(unittest.TestCase):
                       for s in range(1, num_inputs)
                   ]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_pardo_windowed_side_inputs(self):
     with self.create_pipeline() as p:
       # Now with some windowing.
@@ -421,7 +413,6 @@ class FnApiRunnerTest(unittest.TestCase):
           ]),
           label='windowed')
 
-  @retry(stop=stop_after_attempt(3))
   def test_flattened_side_input(self, with_transcoding=True):
     with self.create_pipeline() as p:
       main = p | 'main' >> beam.Create([None])
@@ -444,7 +435,6 @@ class FnApiRunnerTest(unittest.TestCase):
                   equal_to([('a', 1), ('b', 2)] + third_element),
                   label='CheckFlattenOfSideInput')
 
-  @retry(stop=stop_after_attempt(3))
   def test_gbk_side_input(self):
     with self.create_pipeline() as p:
       main = p | 'main' >> beam.Create([None])
@@ -455,7 +445,6 @@ class FnApiRunnerTest(unittest.TestCase):
               'a': [1]
           })]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_multimap_side_input(self):
     with self.create_pipeline() as p:
       main = p | 'main' >> beam.Create(['a', 'b'])
@@ -465,7 +454,6 @@ class FnApiRunnerTest(unittest.TestCase):
               lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
           equal_to([('a', [1, 3]), ('b', [2])]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_multimap_multiside_input(self):
     # A test where two transforms in the same stage consume the same PCollection
     # twice as side input.
@@ -487,7 +475,6 @@ class FnApiRunnerTest(unittest.TestCase):
               beam.pvalue.AsList(side)),
           equal_to([('a', [1, 3], [1, 2, 3]), ('b', [2], [1, 2, 3])]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_multimap_side_input_type_coercion(self):
     with self.create_pipeline() as p:
       main = p | 'main' >> beam.Create(['a', 'b'])
@@ -502,7 +489,6 @@ class FnApiRunnerTest(unittest.TestCase):
               lambda k, d: (k, sorted(d[k])), beam.pvalue.AsMultiMap(side)),
           equal_to([('a', [1, 3]), ('b', [2])]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_pardo_unfusable_side_inputs(self):
     def cross_product(elem, sides):
       for side in sides:
@@ -529,7 +515,6 @@ class FnApiRunnerTest(unittest.TestCase):
           pcoll | beam.FlatMap(cross_product, beam.pvalue.AsList(derived)),
           equal_to([('a', 'a'), ('a', 'b'), ('b', 'a'), ('b', 'b')]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_pardo_state_only(self):
     index_state_spec = userstate.CombiningValueStateSpec('index', sum)
     value_and_index_state_spec = userstate.ReadModifyWriteStateSpec(
@@ -557,7 +542,6 @@ class FnApiRunnerTest(unittest.TestCase):
           p | beam.Create(inputs) | beam.ParDo(AddIndex()), equal_to(expected))
 
   @unittest.skip('TestStream not yet supported')
-  @retry(stop=stop_after_attempt(3))
   def test_teststream_pardo_timers(self):
     timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
 
@@ -587,7 +571,6 @@ class FnApiRunnerTest(unittest.TestCase):
       #expected = [('fired', ts) for ts in (20, 200)]
       #assert_that(actual, equal_to(expected))
 
-  @retry(stop=stop_after_attempt(3))
   def test_pardo_timers(self):
     timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
     state_spec = userstate.CombiningValueStateSpec('num_called', sum)
@@ -619,7 +602,6 @@ class FnApiRunnerTest(unittest.TestCase):
       expected = [('fired', ts) for ts in (20, 200, 40, 400)]
       assert_that(actual, equal_to(expected))
 
-  @retry(stop=stop_after_attempt(3))
   def test_pardo_timers_clear(self):
     timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
     clear_timer_spec = userstate.TimerSpec(
@@ -655,15 +637,12 @@ class FnApiRunnerTest(unittest.TestCase):
       expected = [('fired', ts) for ts in (20, 200)]
       assert_that(actual, equal_to(expected))
 
-  @retry(stop=stop_after_attempt(3))
   def test_pardo_state_timers(self):
     self._run_pardo_state_timers(windowed=False)
 
-  @retry(stop=stop_after_attempt(3))
   def test_pardo_state_timers_non_standard_coder(self):
     self._run_pardo_state_timers(windowed=False, key_type=Any)
 
-  @retry(stop=stop_after_attempt(3))
   def test_windowed_pardo_state_timers(self):
     self._run_pardo_state_timers(windowed=True)
 
@@ -732,7 +711,6 @@ class FnApiRunnerTest(unittest.TestCase):
 
       assert_that(actual, is_buffered_correctly)
 
-  @retry(stop=stop_after_attempt(3))
   def test_pardo_dynamic_timer(self):
     class DynamicTimerDoFn(beam.DoFn):
       dynamic_timer_spec = userstate.TimerSpec(
@@ -757,7 +735,6 @@ class FnApiRunnerTest(unittest.TestCase):
           | beam.ParDo(DynamicTimerDoFn()))
       assert_that(actual, equal_to([('key1', 10), ('key2', 20), ('key3', 30)]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_sdf(self):
     class ExpandingStringsDoFn(beam.DoFn):
       def process(
@@ -776,7 +753,6 @@ class FnApiRunnerTest(unittest.TestCase):
       actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()))
       assert_that(actual, equal_to(list(''.join(data))))
 
-  @retry(stop=stop_after_attempt(3))
   def test_sdf_with_dofn_as_restriction_provider(self):
     class ExpandingStringsDoFn(beam.DoFn, ExpandStringsProvider):
       def process(
@@ -792,7 +768,6 @@ class FnApiRunnerTest(unittest.TestCase):
       actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()))
       assert_that(actual, equal_to(list(''.join(data))))
 
-  @retry(stop=stop_after_attempt(3))
   def test_sdf_with_check_done_failed(self):
     class ExpandingStringsDoFn(beam.DoFn):
       def process(
@@ -812,7 +787,6 @@ class FnApiRunnerTest(unittest.TestCase):
         data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
         _ = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()))
 
-  @retry(stop=stop_after_attempt(3))
   def test_sdf_with_watermark_tracking(self):
     class ExpandingStringsDoFn(beam.DoFn):
       def process(
@@ -839,7 +813,6 @@ class FnApiRunnerTest(unittest.TestCase):
       actual = (p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn()))
       assert_that(actual, equal_to(list(''.join(data))))
 
-  @retry(stop=stop_after_attempt(3))
   def test_sdf_with_dofn_as_watermark_estimator(self):
     class ExpandingStringsDoFn(beam.DoFn, beam.WatermarkEstimatorProvider):
       def initial_estimator_state(self, element, restriction):
@@ -904,15 +877,12 @@ class FnApiRunnerTest(unittest.TestCase):
       self.assertEqual(1, len(counters))
       self.assertEqual(counters[0].committed, len(''.join(data)))
 
-  @retry(stop=stop_after_attempt(3))
   def test_sdf_with_sdf_initiated_checkpointing(self):
     self.run_sdf_initiated_checkpointing(is_drain=False)
 
-  @retry(stop=stop_after_attempt(3))
   def test_draining_sdf_with_sdf_initiated_checkpointing(self):
     self.run_sdf_initiated_checkpointing(is_drain=True)
 
-  @retry(stop=stop_after_attempt(3))
   def test_sdf_default_truncate_when_bounded(self):
     class SimleSDF(beam.DoFn):
       def process(
@@ -930,7 +900,6 @@ class FnApiRunnerTest(unittest.TestCase):
       actual = (p | beam.Create([10]) | beam.ParDo(SimleSDF()))
       assert_that(actual, equal_to(range(10)))
 
-  @retry(stop=stop_after_attempt(3))
   def test_sdf_default_truncate_when_unbounded(self):
     class SimleSDF(beam.DoFn):
       def process(
@@ -948,7 +917,6 @@ class FnApiRunnerTest(unittest.TestCase):
       actual = (p | beam.Create([10]) | beam.ParDo(SimleSDF()))
       assert_that(actual, equal_to([]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_sdf_with_truncate(self):
     class SimleSDF(beam.DoFn):
       def process(
@@ -966,7 +934,6 @@ class FnApiRunnerTest(unittest.TestCase):
       actual = (p | beam.Create([10]) | beam.ParDo(SimleSDF()))
       assert_that(actual, equal_to(range(5)))
 
-  @retry(stop=stop_after_attempt(3))
   def test_group_by_key(self):
     with self.create_pipeline() as p:
       res = (
@@ -977,13 +944,11 @@ class FnApiRunnerTest(unittest.TestCase):
       assert_that(res, equal_to([('a', [1, 2]), ('b', [3])]))
 
   # Runners may special case the Reshuffle transform urn.
-  @retry(stop=stop_after_attempt(3))
   def test_reshuffle(self):
     with self.create_pipeline() as p:
       assert_that(
           p | beam.Create([1, 2, 3]) | beam.Reshuffle(), equal_to([1, 2, 3]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_flatten(self, with_transcoding=True):
     with self.create_pipeline() as p:
       if with_transcoding:
@@ -997,13 +962,11 @@ class FnApiRunnerTest(unittest.TestCase):
           p | 'd' >> beam.Create(additional)) | beam.Flatten()
       assert_that(res, equal_to(['a', 'b', 'c'] + additional))
 
-  @retry(stop=stop_after_attempt(3))
   def test_flatten_same_pcollections(self, with_transcoding=True):
     with self.create_pipeline() as p:
       pc = p | beam.Create(['a', 'b'])
       assert_that((pc, pc, pc) | beam.Flatten(), equal_to(['a', 'b'] * 3))
 
-  @retry(stop=stop_after_attempt(3))
   def test_combine_per_key(self):
     with self.create_pipeline() as p:
       res = (
@@ -1012,7 +975,6 @@ class FnApiRunnerTest(unittest.TestCase):
           | beam.CombinePerKey(beam.combiners.MeanCombineFn()))
       assert_that(res, equal_to([('a', 1.5), ('b', 3.0)]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_read(self):
     # Can't use NamedTemporaryFile as a context
     # due to https://bugs.python.org/issue14243
@@ -1026,7 +988,6 @@ class FnApiRunnerTest(unittest.TestCase):
     finally:
       os.unlink(temp_file.name)
 
-  @retry(stop=stop_after_attempt(3))
   def test_windowing(self):
     with self.create_pipeline() as p:
       res = (
@@ -1038,7 +999,6 @@ class FnApiRunnerTest(unittest.TestCase):
           | beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1]))))
       assert_that(res, equal_to([('k', [1, 2]), ('k', [100, 101, 102])]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_custom_merging_window(self):
     with self.create_pipeline() as p:
       res = (
@@ -1055,7 +1015,6 @@ class FnApiRunnerTest(unittest.TestCase):
     self.assertEqual(GenericMergingWindowFn._HANDLES, {})
 
   @unittest.skip('BEAM-9119: test is flaky')
-  @retry(stop=stop_after_attempt(3))
   def test_large_elements(self):
     with self.create_pipeline() as p:
       big = (
@@ -1078,7 +1037,6 @@ class FnApiRunnerTest(unittest.TestCase):
       gbk_res = (big | beam.GroupByKey() | beam.Map(lambda x: x[0]))
       assert_that(gbk_res, equal_to(['a', 'b']), label='gbk')
 
-  @retry(stop=stop_after_attempt(3))
   def test_error_message_includes_stage(self):
     with self.assertRaises(BaseException) as e_cm:
       with self.create_pipeline() as p:
@@ -1099,7 +1057,6 @@ class FnApiRunnerTest(unittest.TestCase):
     self.assertIn('StageC', message)
     self.assertNotIn('StageB', message)
 
-  @retry(stop=stop_after_attempt(3))
   def test_error_traceback_includes_user_code(self):
     def first(x):
       return second(x)
@@ -1123,7 +1080,6 @@ class FnApiRunnerTest(unittest.TestCase):
     self.assertIn('second', message)
     self.assertIn('third', message)
 
-  @retry(stop=stop_after_attempt(3))
   def test_no_subtransform_composite(self):
     class First(beam.PTransform):
       def expand(self, pcolls):
@@ -1134,7 +1090,6 @@ class FnApiRunnerTest(unittest.TestCase):
       pcoll_b = p | 'b' >> beam.Create(['b'])
       assert_that((pcoll_a, pcoll_b) | First(), equal_to(['a']))
 
-  @retry(stop=stop_after_attempt(3))
   def test_metrics(self, check_gauge=True):
     p = self.create_pipeline()
 
@@ -1167,7 +1122,6 @@ class FnApiRunnerTest(unittest.TestCase):
                                   .with_name('gauge'))['gauges']
       self.assertEqual(gaug.committed.value, 3)
 
-  @retry(stop=stop_after_attempt(3))
   def test_callbacks_with_exception(self):
     elements_list = ['1', '2']
 
@@ -1187,7 +1141,6 @@ class FnApiRunnerTest(unittest.TestCase):
           | beam.ParDo(FinalizebleDoFnWithException()))
       assert_that(res, equal_to(['1', '2']))
 
-  @retry(stop=stop_after_attempt(3))
   def test_register_finalizations(self):
     event_recorder = EventRecorder(tempfile.gettempdir())
 
@@ -1225,7 +1178,6 @@ class FnApiRunnerTest(unittest.TestCase):
 
     event_recorder.cleanup()
 
-  @retry(stop=stop_after_attempt(3))
   def test_sdf_synthetic_source(self):
     common_attrs = {
         'key_size': 1,
@@ -1256,7 +1208,6 @@ class FnApiRunnerTest(unittest.TestCase):
           | beam.combiners.Count.Globally())
       assert_that(res, equal_to([total_num_records]))
 
-  @retry(stop=stop_after_attempt(3))
   def test_create_value_provider_pipeline_option(self):
     # Verify that the runner can execute a pipeline when there are value
     # provider pipeline options
@@ -1272,7 +1223,6 @@ class FnApiRunnerTest(unittest.TestCase):
     with self.create_pipeline() as p:
       assert_that(p | beam.Create(['a', 'b']), equal_to(['a', 'b']))
 
-  @retry(stop=stop_after_attempt(3))
   def _test_pack_combiners(self, assert_using_counter_names):
     counter = beam.metrics.Metrics.counter('ns', 'num_values')
 
@@ -1319,7 +1269,6 @@ class FnApiRunnerTest(unittest.TestCase):
         self.assertTrue(
             any(re.match(packed_step_name_regex, s) for s in step_names))
 
-  @retry(stop=stop_after_attempt(3))
   def test_pack_combiners(self):
     self._test_pack_combiners(assert_using_counter_names=True)
 
@@ -1404,7 +1353,6 @@ class FnApiRunnerMetricsTest(unittest.TestCase):
   def create_pipeline(self):
     return beam.Pipeline(runner=fn_api_runner.FnApiRunner())
 
-  @retry(stop=stop_after_attempt(3))
   def test_element_count_metrics(self):
     class GenerateTwoOutputs(beam.DoFn):
       def process(self, element):
@@ -1589,7 +1537,6 @@ class FnApiRunnerMetricsTest(unittest.TestCase):
       print(res._monitoring_infos_by_stage)
       raise
 
-  @retry(stop=stop_after_attempt(3))
   def test_non_user_metrics(self):
     p = self.create_pipeline()
 
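
All of the @retry(stop=stop_after_attempt(3)) decorators removed above come from the tenacity library; they papered over transient data-channel failures by re-running whole tests. A minimal sketch of that pattern, shown only to illustrate what the commit deletes (flaky_operation is a made-up stand-in, not a Beam function):

    import random

    from tenacity import retry, stop_after_attempt

    @retry(stop=stop_after_attempt(3))
    def flaky_operation():
        # tenacity re-invokes this on failure; after the third failed
        # attempt it gives up and raises tenacity.RetryError wrapping
        # the last exception (unless reraise=True is passed).
        if random.random() < 0.5:
            raise RuntimeError('transient failure')
        return 'ok'

    print(flaky_operation())

With the retries moved into the gRPC channel itself (next diff), this test-level workaround is no longer needed.
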
diff --git a/sdks/python/apache_beam/runners/worker/data_plane.py b/sdks/python/apache_beam/runners/worker/data_plane.py
index 4baca681d9e..2509eb8d3e8 100644
--- a/sdks/python/apache_beam/runners/worker/data_plane.py
+++ b/sdks/python/apache_beam/runners/worker/data_plane.py
@@ -22,6 +22,7 @@
 
 import abc
 import collections
+import json
 import logging
 import queue
 import threading
@@ -51,6 +52,7 @@ from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
 
 if TYPE_CHECKING:
   import apache_beam.coders.slow_stream
+
   OutputStream = apache_beam.coders.slow_stream.OutputStream
   DataOrTimers = Union[beam_fn_api_pb2.Elements.Data,
                        beam_fn_api_pb2.Elements.Timers]
@@ -68,6 +70,22 @@ _DEFAULT_TIME_FLUSH_THRESHOLD_MS = 0  # disable time-based flush by default
 # can have up to _MAX_CLEANED_INSTRUCTIONS items. See _GrpcDataChannel.
 _MAX_CLEANED_INSTRUCTIONS = 10000
 
+# retry on transient UNAVAILABLE grpc error from data channels.
+_GRPC_SERVICE_CONFIG = json.dumps({
+    "methodConfig": [{
+        "name": [{
+            "service": "org.apache.beam.model.fn_execution.v1.BeamFnData"
+        }],
+        "retryPolicy": {
+            "maxAttempts": 5,
+            "initialBackoff": "0.1s",
+            "maxBackoff": "5s",
+            "backoffMultiplier": 2,
+            "retryableStatusCodes": ["UNAVAILABLE"],
+        },
+    }]
+})
+
 
 class ClosableOutputStream(OutputStream):
   """A Outputstream for use with CoderImpls that has a close() method."""
@@ -111,6 +129,7 @@ class ClosableOutputStream(OutputStream):
 
 class SizeBasedBufferingClosableOutputStream(ClosableOutputStream):
   """A size-based buffering OutputStream."""
+
   def __init__(
       self,
       close_callback=None,  # type: Optional[Callable[[bytes], None]]
@@ -185,6 +204,7 @@ class TimeBasedBufferingClosableOutputStream(
 
 class PeriodicThread(threading.Thread):
   """Call a function periodically with the specified number of seconds"""
+
   def __init__(
       self,
       interval,  # type: float
@@ -656,6 +676,7 @@ class _GrpcDataChannel(DataChannel):
 
 class GrpcClientDataChannel(_GrpcDataChannel):
   """A DataChannel wrapping the client side of a BeamFnData connection."""
+
   def __init__(
       self,
       data_stub,  # type: beam_fn_api_pb2_grpc.BeamFnDataStub
@@ -724,6 +745,7 @@ class GrpcClientDataChannelFactory(DataChannelFactory):
 
   Caches the created channels by ``data descriptor url``.
   """
+
   def __init__(
       self,
       credentials=None,  # type: Any
@@ -752,7 +774,8 @@ class GrpcClientDataChannelFactory(DataChannelFactory):
           # received or sent over the data plane. The actual buffer size
           # is controlled in a layer above.
           channel_options = [("grpc.max_receive_message_length", -1),
-                             ("grpc.max_send_message_length", -1)]
+                             ("grpc.max_send_message_length", -1),
+                             ("grpc.service_config", _GRPC_SERVICE_CONFIG)]
           grpc_channel = None
           if self._credentials is None:
             grpc_channel = GRPCChannelFactory.insecure_channel(
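
A note on the schedule the retryPolicy above implies: per gRPC's retry design, each retry waits a random delay drawn from [0, current_backoff], with the backoff growing from initialBackoff by backoffMultiplier up to maxBackoff. A quick sketch of the upper bounds, assuming the config values added in this commit:

    # Upper bound on the delay before each possible retry
    # (maxAttempts=5 means 1 original call plus up to 4 retries).
    initial, multiplier, cap = 0.1, 2, 5.0
    bounds = []
    backoff = initial
    for _ in range(4):
        bounds.append(min(backoff, cap))
        backoff *= multiplier
    print(bounds)  # [0.1, 0.2, 0.4, 0.8]

So the 5s maxBackoff is generous headroom; with 5 attempts the backoff never reaches the cap.
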
diff --git a/sdks/python/apache_beam/runners/worker/sdk_worker.py b/sdks/python/apache_beam/runners/worker/sdk_worker.py
index 38e60f86ceb..562c3139739 100644
--- a/sdks/python/apache_beam/runners/worker/sdk_worker.py
+++ b/sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -24,6 +24,7 @@ import abc
 import collections
 import contextlib
 import functools
+import json
 import logging
 import queue
 import sys
@@ -85,6 +86,22 @@ MAX_KNOWN_NOT_RUNNING_INSTRUCTIONS = 1000
 # will remember for failed instructions.
 MAX_FAILED_INSTRUCTIONS = 10000
 
+# retry on transient UNAVAILABLE grpc error from state channels.
+_GRPC_SERVICE_CONFIG = json.dumps({
+    "methodConfig": [{
+        "name": [{
+            "service": "org.apache.beam.model.fn_execution.v1.BeamFnState"
+        }],
+        "retryPolicy": {
+            "maxAttempts": 5,
+            "initialBackoff": "0.1s",
+            "maxBackoff": "5s",
+            "backoffMultiplier": 2,
+            "retryableStatusCodes": ["UNAVAILABLE"],
+        },
+    }]
+})
+
 
 class ShortIdCache(object):
   """ Cache for MonitoringInfo "short ids"
@@ -835,7 +852,8 @@ class GrpcStateHandlerFactory(StateHandlerFactory):
           # received or sent over the data plane. The actual buffer size is
           # controlled in a layer above.
           options = [('grpc.max_receive_message_length', -1),
-                     ('grpc.max_send_message_length', -1)]
+                     ('grpc.max_send_message_length', -1),
+                     ('grpc.service_config', _GRPC_SERVICE_CONFIG)]
           if self._credentials is None:
             _LOGGER.info('Creating insecure state channel for %s.', url)
             grpc_channel = GRPCChannelFactory.insecure_channel(
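
The policy is duplicated here for the BeamFnState service because the state channel is created separately from the data channel. If both services ever shared one channel, a single config could cover them, since methodConfig.name accepts a list of service selectors. A hypothetical combined config, not part of this commit:

    import json

    # Hypothetical: one methodConfig entry naming both Fn API services.
    _COMBINED_CONFIG = json.dumps({
        "methodConfig": [{
            "name": [
                {"service": "org.apache.beam.model.fn_execution.v1.BeamFnData"},
                {"service": "org.apache.beam.model.fn_execution.v1.BeamFnState"},
            ],
            "retryPolicy": {
                "maxAttempts": 5,
                "initialBackoff": "0.1s",
                "maxBackoff": "5s",
                "backoffMultiplier": 2,
                "retryableStatusCodes": ["UNAVAILABLE"],
            },
        }]
    })
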