You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2023/02/03 21:41:32 UTC

[beam] branch master updated: Optimize to use cached output receiver instead of creating one on DoFn invocation #21250 (#25245)

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b1c9d8aec07 Optimize to use cached output receiver instead of creating one on DoFn invocation #21250 (#25245)
b1c9d8aec07 is described below

commit b1c9d8aec07ce72e946bd349eb4345417efbfc9c
Author: Luke Cwik <lc...@google.com>
AuthorDate: Fri Feb 3 13:41:24 2023 -0800

    Optimize to use cached output receiver instead of creating one on DoFn invocation #21250 (#25245)
    
    This shows up whenever transforms use output receivers, for example map/flatmap, where the calls are expected to be really inexpensive, so we don't want to take on the overhead of creating an object on each invocation.
    
    We saw a small overall performance improvement, but the biggest benefit was that we reduced the size of the stack by 1 in these scenarios.
    
    Before:
    ```
    Benchmark                                        Mode  Cnt      Score     Error  Units
    ProcessBundleBenchmark.testLargeBundle          thrpt   15   3147.619 ± 130.414  ops/s
    ```
    
    After:
    ```
    Benchmark                                        Mode  Cnt      Score     Error  Units
    ProcessBundleBenchmark.testLargeBundle          thrpt   15   3251.226 ± 138.822  ops/s
    ```
---
 .../beam/sdk/transforms/DoFnOutputReceivers.java   |   2 +-
 .../apache/beam/fn/harness/FnApiDoFnRunner.java    | 310 +++++++++++++++++++--
 2 files changed, 295 insertions(+), 17 deletions(-)

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnOutputReceivers.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnOutputReceivers.java
index a17264da35d..27fbb9754ec 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnOutputReceivers.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnOutputReceivers.java
@@ -115,7 +115,7 @@ public class DoFnOutputReceivers {
       checkState(outputCoder != null, "No output tag for " + tag);
       checkState(
           outputCoder instanceof SchemaCoder,
-          "Output with tag " + tag + " must have a schema in order to call " + " getRowReceiver");
+          "Output with tag " + tag + " must have a schema in order to call getRowReceiver");
       return DoFnOutputReceivers.rowReceiver(context, tag, (SchemaCoder<T>) outputCoder);
     }
   }
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 2b449e0200b..13d85d27006 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -85,7 +85,6 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
 import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
 import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
-import org.apache.beam.sdk.transforms.DoFnOutputReceivers;
 import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
@@ -2413,7 +2412,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
 
   /** Base implementation that does not override methods which need to be window aware. */
   private abstract class ProcessBundleContextBase extends DoFn<InputT, OutputT>.ProcessContext
-      implements DoFnInvoker.ArgumentProvider<InputT, OutputT> {
+      implements DoFnInvoker.ArgumentProvider<InputT, OutputT>, OutputReceiver<OutputT> {
 
     private ProcessBundleContextBase() {
       doFn.super();
@@ -2478,17 +2477,112 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
 
     @Override
     public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.windowedReceiver(this, null);
+      return this;
     }
 
+    private final OutputReceiver<Row> mainRowOutputReceiver =
+        mainOutputSchemaCoder == null
+            ? null
+            : new OutputReceiver<Row>() {
+              private final SerializableFunction<Row, OutputT> fromRowFunction =
+                  mainOutputSchemaCoder.getFromRowFunction();
+
+              @Override
+              public void output(Row output) {
+                ProcessBundleContextBase.this.outputWithTimestamp(
+                    fromRowFunction.apply(output), currentElement.getTimestamp());
+              }
+
+              @Override
+              public void outputWithTimestamp(Row output, Instant timestamp) {
+                ProcessBundleContextBase.this.outputWithTimestamp(
+                    fromRowFunction.apply(output), timestamp);
+              }
+            };
+
     @Override
     public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.rowReceiver(this, null, mainOutputSchemaCoder);
-    }
+      checkState(
+          mainOutputSchemaCoder != null,
+          "Output with tag "
+              + mainOutputTag
+              + " must have a schema in order to call getRowReceiver");
+      return mainRowOutputReceiver;
+    }
+
+    /** A {@link MultiOutputReceiver} which caches created instances to re-use across bundles. */
+    private final MultiOutputReceiver taggedOutputReceiver =
+        new MultiOutputReceiver() {
+          private final Map<TupleTag<?>, OutputReceiver<?>> taggedOutputReceivers = new HashMap<>();
+          private final Map<TupleTag<?>, OutputReceiver<Row>> taggedRowReceivers = new HashMap<>();
+
+          private <T> OutputReceiver<T> createTaggedOutputReceiver(TupleTag<T> tag) {
+            if (tag == null || mainOutputTag.equals(tag)) { // main output: reuse the context
+              return (OutputReceiver<T>) ProcessBundleContextBase.this;
+            }
+            return new OutputReceiver<T>() {
+              @Override
+              public void output(T output) {
+                ProcessBundleContextBase.this.outputWithTimestamp(
+                    tag, output, currentElement.getTimestamp());
+              }
+
+              @Override
+              public void outputWithTimestamp(T output, Instant timestamp) {
+                // Forward the caller-supplied timestamp, not the current element's.
+                ProcessBundleContextBase.this.outputWithTimestamp(tag, output, timestamp);
+              }
+            };
+          }
+
+          private <T> OutputReceiver<Row> createTaggedRowReceiver(TupleTag<T> tag) {
+            if (tag == null || mainOutputTag.equals(tag)) {
+              checkState(
+                  mainOutputSchemaCoder != null,
+                  "Output with tag "
+                      + mainOutputTag
+                      + " must have a schema in order to call getRowReceiver");
+              return mainRowOutputReceiver;
+            }
+
+            Coder<T> outputCoder = (Coder<T>) outputCoders.get(tag);
+            checkState(outputCoder != null, "No output tag for " + tag);
+            checkState(
+                outputCoder instanceof SchemaCoder,
+                "Output with tag " + tag + " must have a schema in order to call getRowReceiver");
+            return new OutputReceiver<Row>() {
+              private SerializableFunction<Row, T> fromRowFunction =
+                  ((SchemaCoder) outputCoder).getFromRowFunction();
+
+              @Override
+              public void output(Row output) {
+                ProcessBundleContextBase.this.outputWithTimestamp(
+                    tag, fromRowFunction.apply(output), currentElement.getTimestamp());
+              }
+
+              @Override
+              public void outputWithTimestamp(Row output, Instant timestamp) {
+                ProcessBundleContextBase.this.outputWithTimestamp(
+                    tag, fromRowFunction.apply(output), timestamp);
+              }
+            };
+          }
+
+          @Override
+          public <T> OutputReceiver<T> get(TupleTag<T> tag) {
+            return (OutputReceiver<T>)
+                taggedOutputReceivers.computeIfAbsent(tag, this::createTaggedOutputReceiver);
+          }
+
+          @Override
+          public <T> OutputReceiver<Row> getRowReceiver(TupleTag<T> tag) {
+            return taggedRowReceivers.computeIfAbsent(tag, this::createTaggedRowReceiver);
+          }
+        };
 
     @Override
     public MultiOutputReceiver taggedOutputReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.windowedMultiReceiver(this, outputCoders);
+      return taggedOutputReceiver;
     }
 
     @Override
@@ -2563,7 +2657,8 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
    * DoFn.OnWindowExpiration @OnWindowExpiration}.
    */
   private class OnWindowExpirationContext<K> extends BaseArgumentProvider<InputT, OutputT> {
-    private class Context extends DoFn<InputT, OutputT>.OnWindowExpirationContext {
+    private class Context extends DoFn<InputT, OutputT>.OnWindowExpirationContext
+        implements OutputReceiver<OutputT> {
       private Context() {
         doFn.super();
       }
@@ -2671,17 +2766,108 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
 
     @Override
     public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.windowedReceiver(context, null);
+      return context;
     }
 
+    private final OutputReceiver<Row> mainRowOutputReceiver =
+        mainOutputSchemaCoder == null
+            ? null
+            : new OutputReceiver<Row>() {
+              private final SerializableFunction<Row, OutputT> fromRowFunction =
+                  mainOutputSchemaCoder.getFromRowFunction();
+
+              @Override
+              public void output(Row output) {
+                context.outputWithTimestamp(
+                    fromRowFunction.apply(output), currentElement.getTimestamp());
+              }
+
+              @Override
+              public void outputWithTimestamp(Row output, Instant timestamp) {
+                context.outputWithTimestamp(fromRowFunction.apply(output), timestamp);
+              }
+            };
+
     @Override
     public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.rowReceiver(context, null, mainOutputSchemaCoder);
-    }
+      checkState(
+          mainOutputSchemaCoder != null,
+          "Output with tag "
+              + mainOutputTag
+              + " must have a schema in order to call getRowReceiver");
+      return mainRowOutputReceiver;
+    }
+
+    /** A {@link MultiOutputReceiver} which caches created instances to re-use across bundles. */
+    private final MultiOutputReceiver taggedOutputReceiver =
+        new MultiOutputReceiver() {
+          private final Map<TupleTag<?>, OutputReceiver<?>> taggedOutputReceivers = new HashMap<>();
+          private final Map<TupleTag<?>, OutputReceiver<Row>> taggedRowReceivers = new HashMap<>();
+
+          private <T> OutputReceiver<T> createTaggedOutputReceiver(TupleTag<T> tag) {
+            if (tag == null || mainOutputTag.equals(tag)) { // main output: reuse the context
+              return (OutputReceiver<T>) context;
+            }
+            return new OutputReceiver<T>() {
+              @Override
+              public void output(T output) { // NOTE(review): confirm currentElement is set here
+                context.outputWithTimestamp(tag, output, currentElement.getTimestamp());
+              }
+
+              @Override
+              public void outputWithTimestamp(T output, Instant timestamp) {
+                context.outputWithTimestamp(tag, output, timestamp); // honor caller's timestamp
+              }
+            };
+          }
+
+          private <T> OutputReceiver<Row> createTaggedRowReceiver(TupleTag<T> tag) {
+            if (tag == null || mainOutputTag.equals(tag)) {
+              checkState(
+                  mainOutputSchemaCoder != null,
+                  "Output with tag "
+                      + mainOutputTag
+                      + " must have a schema in order to call getRowReceiver");
+              return mainRowOutputReceiver;
+            }
+
+            Coder<T> outputCoder = (Coder<T>) outputCoders.get(tag);
+            checkState(outputCoder != null, "No output tag for " + tag);
+            checkState(
+                outputCoder instanceof SchemaCoder,
+                "Output with tag " + tag + " must have a schema in order to call getRowReceiver");
+            return new OutputReceiver<Row>() {
+              private SerializableFunction<Row, T> fromRowFunction =
+                  ((SchemaCoder) outputCoder).getFromRowFunction();
+
+              @Override
+              public void output(Row output) {
+                context.outputWithTimestamp(
+                    tag, fromRowFunction.apply(output), currentElement.getTimestamp());
+              }
+
+              @Override
+              public void outputWithTimestamp(Row output, Instant timestamp) {
+                context.outputWithTimestamp(tag, fromRowFunction.apply(output), timestamp);
+              }
+            };
+          }
+
+          @Override
+          public <T> OutputReceiver<T> get(TupleTag<T> tag) {
+            return (OutputReceiver<T>)
+                taggedOutputReceivers.computeIfAbsent(tag, this::createTaggedOutputReceiver);
+          }
+
+          @Override
+          public <T> OutputReceiver<Row> getRowReceiver(TupleTag<T> tag) {
+            return taggedRowReceivers.computeIfAbsent(tag, this::createTaggedRowReceiver);
+          }
+        };
 
     @Override
     public MultiOutputReceiver taggedOutputReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.windowedMultiReceiver(context);
+      return taggedOutputReceiver;
     }
 
     @Override
@@ -2716,7 +2902,8 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
   /** Provides arguments for a {@link DoFnInvoker} for {@link DoFn.OnTimer @OnTimer}. */
   private class OnTimerContext<K> extends BaseArgumentProvider<InputT, OutputT> {
 
-    private class Context extends DoFn<InputT, OutputT>.OnTimerContext {
+    private class Context extends DoFn<InputT, OutputT>.OnTimerContext
+        implements OutputReceiver<OutputT> {
       private Context() {
         doFn.super();
       }
@@ -2840,17 +3027,108 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator
 
     @Override
     public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.windowedReceiver(context, null);
+      return context;
     }
 
+    private final OutputReceiver<Row> mainRowOutputReceiver =
+        mainOutputSchemaCoder == null
+            ? null
+            : new OutputReceiver<Row>() {
+              private final SerializableFunction<Row, OutputT> fromRowFunction =
+                  mainOutputSchemaCoder.getFromRowFunction();
+
+              @Override
+              public void output(Row output) {
+                context.outputWithTimestamp(
+                    fromRowFunction.apply(output), currentElement.getTimestamp());
+              }
+
+              @Override
+              public void outputWithTimestamp(Row output, Instant timestamp) {
+                context.outputWithTimestamp(fromRowFunction.apply(output), timestamp);
+              }
+            };
+
     @Override
     public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.rowReceiver(context, null, mainOutputSchemaCoder);
-    }
+      checkState(
+          mainOutputSchemaCoder != null,
+          "Output with tag "
+              + mainOutputTag
+              + " must have a schema in order to call getRowReceiver");
+      return mainRowOutputReceiver;
+    }
+
+    /** A {@link MultiOutputReceiver} which caches created instances to re-use across bundles. */
+    private final MultiOutputReceiver taggedOutputReceiver =
+        new MultiOutputReceiver() {
+          private final Map<TupleTag<?>, OutputReceiver<?>> taggedOutputReceivers = new HashMap<>();
+          private final Map<TupleTag<?>, OutputReceiver<Row>> taggedRowReceivers = new HashMap<>();
+
+          private <T> OutputReceiver<T> createTaggedOutputReceiver(TupleTag<T> tag) {
+            if (tag == null || mainOutputTag.equals(tag)) { // main output: reuse the context
+              return (OutputReceiver<T>) context;
+            }
+            return new OutputReceiver<T>() {
+              @Override
+              public void output(T output) { // NOTE(review): confirm currentElement is set here
+                context.outputWithTimestamp(tag, output, currentElement.getTimestamp());
+              }
+
+              @Override
+              public void outputWithTimestamp(T output, Instant timestamp) {
+                context.outputWithTimestamp(tag, output, timestamp); // honor caller's timestamp
+              }
+            };
+          }
+
+          private <T> OutputReceiver<Row> createTaggedRowReceiver(TupleTag<T> tag) {
+            if (tag == null || mainOutputTag.equals(tag)) {
+              checkState(
+                  mainOutputSchemaCoder != null,
+                  "Output with tag "
+                      + mainOutputTag
+                      + " must have a schema in order to call getRowReceiver");
+              return mainRowOutputReceiver;
+            }
+
+            Coder<T> outputCoder = (Coder<T>) outputCoders.get(tag);
+            checkState(outputCoder != null, "No output tag for " + tag);
+            checkState(
+                outputCoder instanceof SchemaCoder,
+                "Output with tag " + tag + " must have a schema in order to call getRowReceiver");
+            return new OutputReceiver<Row>() {
+              private SerializableFunction<Row, T> fromRowFunction =
+                  ((SchemaCoder) outputCoder).getFromRowFunction();
+
+              @Override
+              public void output(Row output) {
+                context.outputWithTimestamp(
+                    tag, fromRowFunction.apply(output), currentElement.getTimestamp());
+              }
+
+              @Override
+              public void outputWithTimestamp(Row output, Instant timestamp) {
+                context.outputWithTimestamp(tag, fromRowFunction.apply(output), timestamp);
+              }
+            };
+          }
+
+          @Override
+          public <T> OutputReceiver<T> get(TupleTag<T> tag) {
+            return (OutputReceiver<T>)
+                taggedOutputReceivers.computeIfAbsent(tag, this::createTaggedOutputReceiver);
+          }
+
+          @Override
+          public <T> OutputReceiver<Row> getRowReceiver(TupleTag<T> tag) {
+            return taggedRowReceivers.computeIfAbsent(tag, this::createTaggedRowReceiver);
+          }
+        };
 
     @Override
     public MultiOutputReceiver taggedOutputReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.windowedMultiReceiver(context);
+      return taggedOutputReceiver;
     }
 
     @Override