This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new b1c9d8aec07 Optimize to use cached output receiver instead of creating one on DoFn invocation #21250 (#25245)
b1c9d8aec07 is described below

commit b1c9d8aec07ce72e946bd349eb4345417efbfc9c
Author: Luke Cwik <lc...@google.com>
AuthorDate: Fri Feb 3 13:41:24 2023 -0800

    Optimize to use cached output receiver instead of creating one on DoFn invocation #21250 (#25245)
    
    This shows up whenever transforms use output receivers, for example in
    map/flatmap, where each call is expected to be very inexpensive, so we don't
    want to take on the overhead of creating an object per invocation.
    
    We saw a small overall performance improvement, but the biggest benefit is
    that the stack depth is reduced by one frame in these scenarios.
    
    Before:
    ```
    Benchmark                                        Mode  Cnt      Score     Error  Units
    ProcessBundleBenchmark.testLargeBundle          thrpt   15   3147.619 ± 130.414  ops/s
    ```
    
    After:
    ```
    Benchmark                                        Mode  Cnt      Score     Error  Units
    ProcessBundleBenchmark.testLargeBundle          thrpt   15   3251.226 ± 138.822  ops/s
    ```
---
 .../beam/sdk/transforms/DoFnOutputReceivers.java   |   2 +-
 .../apache/beam/fn/harness/FnApiDoFnRunner.java    | 310 +++++++++++++++++++--
 2 files changed, 295 insertions(+), 17 deletions(-)

diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnOutputReceivers.java
 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnOutputReceivers.java
index a17264da35d..27fbb9754ec 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnOutputReceivers.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnOutputReceivers.java
@@ -115,7 +115,7 @@ public class DoFnOutputReceivers {
       checkState(outputCoder != null, "No output tag for " + tag);
       checkState(
           outputCoder instanceof SchemaCoder,
-          "Output with tag " + tag + " must have a schema in order to call " + 
" getRowReceiver");
+          "Output with tag " + tag + " must have a schema in order to call 
getRowReceiver");
       return DoFnOutputReceivers.rowReceiver(context, tag, (SchemaCoder<T>) 
outputCoder);
     }
   }
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 2b449e0200b..13d85d27006 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -85,7 +85,6 @@ import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
 import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
 import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
-import org.apache.beam.sdk.transforms.DoFnOutputReceivers;
 import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
@@ -2413,7 +2412,7 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
 
   /** Base implementation that does not override methods which need to be 
window aware. */
   private abstract class ProcessBundleContextBase extends DoFn<InputT, 
OutputT>.ProcessContext
-      implements DoFnInvoker.ArgumentProvider<InputT, OutputT> {
+      implements DoFnInvoker.ArgumentProvider<InputT, OutputT>, 
OutputReceiver<OutputT> {
 
     private ProcessBundleContextBase() {
       doFn.super();
@@ -2478,17 +2477,112 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator

     @Override
     public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.windowedReceiver(this, null);
+      return this;
     }

+    private final OutputReceiver<Row> mainRowOutputReceiver =
+        mainOutputSchemaCoder == null
+            ? null
+            : new OutputReceiver<Row>() {
+              private final SerializableFunction<Row, OutputT> fromRowFunction =
+                  mainOutputSchemaCoder.getFromRowFunction();
+
+              @Override
+              public void output(Row output) {
+                // Same path as outputReceiver(), which hands out this context directly.
+                ProcessBundleContextBase.this.output(fromRowFunction.apply(output));
+              }
+
+              @Override
+              public void outputWithTimestamp(Row output, Instant timestamp) {
+                ProcessBundleContextBase.this.outputWithTimestamp(
+                    fromRowFunction.apply(output), timestamp);
+              }
+            };
+
     @Override
     public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.rowReceiver(this, null, mainOutputSchemaCoder);
-    }
+      checkState(
+          mainOutputSchemaCoder != null,
+          "Output with tag "
+              + mainOutputTag
+              + " must have a schema in order to call getRowReceiver");
+      return mainRowOutputReceiver;
+    }
+
+    /** A {@link MultiOutputReceiver} which caches created instances to re-use across bundles. */
+    private final MultiOutputReceiver taggedOutputReceiver =
+        new MultiOutputReceiver() {
+          private final Map<TupleTag<?>, OutputReceiver<?>> taggedOutputReceivers = new HashMap<>();
+          private final Map<TupleTag<?>, OutputReceiver<Row>> taggedRowReceivers = new HashMap<>();
+
+          private <T> OutputReceiver<T> createTaggedOutputReceiver(TupleTag<T> tag) {
+            if (tag == null || mainOutputTag.equals(tag)) {
+              return (OutputReceiver<T>) ProcessBundleContextBase.this;
+            }
+            return new OutputReceiver<T>() {
+              @Override
+              public void output(T output) {
+                // Delegate to the context, as the main-output receiver does.
+                ProcessBundleContextBase.this.output(tag, output);
+              }
+
+              @Override
+              public void outputWithTimestamp(T output, Instant timestamp) {
+                // Forward the caller-supplied timestamp rather than the element's.
+                ProcessBundleContextBase.this.outputWithTimestamp(tag, output, timestamp);
+              }
+            };
+          }
+
+          private <T> OutputReceiver<Row> createTaggedRowReceiver(TupleTag<T> tag) {
+            if (tag == null || mainOutputTag.equals(tag)) {
+              checkState(
+                  mainOutputSchemaCoder != null,
+                  "Output with tag "
+                      + mainOutputTag
+                      + " must have a schema in order to call getRowReceiver");
+              return mainRowOutputReceiver;
+            }
+
+            Coder<T> outputCoder = (Coder<T>) outputCoders.get(tag);
+            checkState(outputCoder != null, "No output tag for " + tag);
+            checkState(
+                outputCoder instanceof SchemaCoder,
+                "Output with tag " + tag + " must have a schema in order to call getRowReceiver");
+            return new OutputReceiver<Row>() {
+              private final SerializableFunction<Row, T> fromRowFunction =
+                  ((SchemaCoder) outputCoder).getFromRowFunction();
+
+              @Override
+              public void output(Row output) {
+                // Delegate to the context, as the main-output receiver does.
+                ProcessBundleContextBase.this.output(tag, fromRowFunction.apply(output));
+              }
+
+              @Override
+              public void outputWithTimestamp(Row output, Instant timestamp) {
+                ProcessBundleContextBase.this.outputWithTimestamp(
+                    tag, fromRowFunction.apply(output), timestamp);
+              }
+            };
+          }
+
+          @Override
+          public <T> OutputReceiver<T> get(TupleTag<T> tag) {
+            return (OutputReceiver<T>)
+                taggedOutputReceivers.computeIfAbsent(tag, this::createTaggedOutputReceiver);
+          }
+
+          @Override
+          public <T> OutputReceiver<Row> getRowReceiver(TupleTag<T> tag) {
+            return taggedRowReceivers.computeIfAbsent(tag, this::createTaggedRowReceiver);
+          }
+        };

     @Override
     public MultiOutputReceiver taggedOutputReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.windowedMultiReceiver(this, outputCoders);
+      return taggedOutputReceiver;
     }

     @Override
@@ -2563,7 +2657,8 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
    * DoFn.OnWindowExpiration @OnWindowExpiration}.
    */
   private class OnWindowExpirationContext<K> extends 
BaseArgumentProvider<InputT, OutputT> {
-    private class Context extends DoFn<InputT, 
OutputT>.OnWindowExpirationContext {
+    private class Context extends DoFn<InputT, 
OutputT>.OnWindowExpirationContext
+        implements OutputReceiver<OutputT> {
       private Context() {
         doFn.super();
       }
@@ -2671,17 +2766,108 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator

     @Override
     public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.windowedReceiver(context, null);
+      return context;
     }

+    private final OutputReceiver<Row> mainRowOutputReceiver =
+        mainOutputSchemaCoder == null
+            ? null
+            : new OutputReceiver<Row>() {
+              private final SerializableFunction<Row, OutputT> fromRowFunction =
+                  mainOutputSchemaCoder.getFromRowFunction();
+
+              @Override
+              public void output(Row output) {
+                // Same path as outputReceiver(), which hands out the context directly.
+                context.output(fromRowFunction.apply(output));
+              }
+
+              @Override
+              public void outputWithTimestamp(Row output, Instant timestamp) {
+                context.outputWithTimestamp(fromRowFunction.apply(output), timestamp);
+              }
+            };
+
     @Override
     public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.rowReceiver(context, null, mainOutputSchemaCoder);
-    }
+      checkState(
+          mainOutputSchemaCoder != null,
+          "Output with tag "
+              + mainOutputTag
+              + " must have a schema in order to call getRowReceiver");
+      return mainRowOutputReceiver;
+    }
+
+    /** A {@link MultiOutputReceiver} which caches created instances to re-use across bundles. */
+    private final MultiOutputReceiver taggedOutputReceiver =
+        new MultiOutputReceiver() {
+          private final Map<TupleTag<?>, OutputReceiver<?>> taggedOutputReceivers = new HashMap<>();
+          private final Map<TupleTag<?>, OutputReceiver<Row>> taggedRowReceivers = new HashMap<>();
+
+          private <T> OutputReceiver<T> createTaggedOutputReceiver(TupleTag<T> tag) {
+            if (tag == null || mainOutputTag.equals(tag)) {
+              return (OutputReceiver<T>) context;
+            }
+            return new OutputReceiver<T>() {
+              @Override
+              public void output(T output) {
+                context.output(tag, output);
+              }
+
+              @Override
+              public void outputWithTimestamp(T output, Instant timestamp) {
+                context.outputWithTimestamp(tag, output, timestamp);
+              }
+            };
+          }
+
+          private <T> OutputReceiver<Row> createTaggedRowReceiver(TupleTag<T> tag) {
+            if (tag == null || mainOutputTag.equals(tag)) {
+              checkState(
+                  mainOutputSchemaCoder != null,
+                  "Output with tag "
+                      + mainOutputTag
+                      + " must have a schema in order to call getRowReceiver");
+              return mainRowOutputReceiver;
+            }
+
+            Coder<T> outputCoder = (Coder<T>) outputCoders.get(tag);
+            checkState(outputCoder != null, "No output tag for " + tag);
+            checkState(
+                outputCoder instanceof SchemaCoder,
+                "Output with tag " + tag + " must have a schema in order to call getRowReceiver");
+            return new OutputReceiver<Row>() {
+              private final SerializableFunction<Row, T> fromRowFunction =
+                  ((SchemaCoder) outputCoder).getFromRowFunction();
+
+              @Override
+              public void output(Row output) {
+                // Delegate to the context, as the main-output receiver does.
+                context.output(tag, fromRowFunction.apply(output));
+              }
+
+              @Override
+              public void outputWithTimestamp(Row output, Instant timestamp) {
+                context.outputWithTimestamp(tag, fromRowFunction.apply(output), timestamp);
+              }
+            };
+          }
+
+          @Override
+          public <T> OutputReceiver<T> get(TupleTag<T> tag) {
+            return (OutputReceiver<T>)
+                taggedOutputReceivers.computeIfAbsent(tag, this::createTaggedOutputReceiver);
+          }
+
+          @Override
+          public <T> OutputReceiver<Row> getRowReceiver(TupleTag<T> tag) {
+            return taggedRowReceivers.computeIfAbsent(tag, this::createTaggedRowReceiver);
+          }
+        };

     @Override
     public MultiOutputReceiver taggedOutputReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.windowedMultiReceiver(context);
+      return taggedOutputReceiver;
     }

     @Override
@@ -2716,7 +2902,8 @@ public class FnApiDoFnRunner<InputT, RestrictionT, 
PositionT, WatermarkEstimator
   /** Provides arguments for a {@link DoFnInvoker} for {@link DoFn.OnTimer 
@OnTimer}. */
   private class OnTimerContext<K> extends BaseArgumentProvider<InputT, 
OutputT> {
 
-    private class Context extends DoFn<InputT, OutputT>.OnTimerContext {
+    private class Context extends DoFn<InputT, OutputT>.OnTimerContext
+        implements OutputReceiver<OutputT> {
       private Context() {
         doFn.super();
       }
@@ -2840,17 +3027,108 @@ public class FnApiDoFnRunner<InputT, RestrictionT, PositionT, WatermarkEstimator

     @Override
     public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.windowedReceiver(context, null);
+      return context;
     }

+    private final OutputReceiver<Row> mainRowOutputReceiver =
+        mainOutputSchemaCoder == null
+            ? null
+            : new OutputReceiver<Row>() {
+              private final SerializableFunction<Row, OutputT> fromRowFunction =
+                  mainOutputSchemaCoder.getFromRowFunction();
+
+              @Override
+              public void output(Row output) {
+                // Same path as outputReceiver(), which hands out the context directly.
+                context.output(fromRowFunction.apply(output));
+              }
+
+              @Override
+              public void outputWithTimestamp(Row output, Instant timestamp) {
+                context.outputWithTimestamp(fromRowFunction.apply(output), timestamp);
+              }
+            };
+
     @Override
     public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.rowReceiver(context, null, mainOutputSchemaCoder);
-    }
+      checkState(
+          mainOutputSchemaCoder != null,
+          "Output with tag "
+              + mainOutputTag
+              + " must have a schema in order to call getRowReceiver");
+      return mainRowOutputReceiver;
+    }
+
+    /** A {@link MultiOutputReceiver} which caches created instances to re-use across bundles. */
+    private final MultiOutputReceiver taggedOutputReceiver =
+        new MultiOutputReceiver() {
+          private final Map<TupleTag<?>, OutputReceiver<?>> taggedOutputReceivers = new HashMap<>();
+          private final Map<TupleTag<?>, OutputReceiver<Row>> taggedRowReceivers = new HashMap<>();
+
+          private <T> OutputReceiver<T> createTaggedOutputReceiver(TupleTag<T> tag) {
+            if (tag == null || mainOutputTag.equals(tag)) {
+              return (OutputReceiver<T>) context;
+            }
+            return new OutputReceiver<T>() {
+              @Override
+              public void output(T output) {
+                context.output(tag, output);
+              }
+
+              @Override
+              public void outputWithTimestamp(T output, Instant timestamp) {
+                context.outputWithTimestamp(tag, output, timestamp);
+              }
+            };
+          }
+
+          private <T> OutputReceiver<Row> createTaggedRowReceiver(TupleTag<T> tag) {
+            if (tag == null || mainOutputTag.equals(tag)) {
+              checkState(
+                  mainOutputSchemaCoder != null,
+                  "Output with tag "
+                      + mainOutputTag
+                      + " must have a schema in order to call getRowReceiver");
+              return mainRowOutputReceiver;
+            }
+
+            Coder<T> outputCoder = (Coder<T>) outputCoders.get(tag);
+            checkState(outputCoder != null, "No output tag for " + tag);
+            checkState(
+                outputCoder instanceof SchemaCoder,
+                "Output with tag " + tag + " must have a schema in order to call getRowReceiver");
+            return new OutputReceiver<Row>() {
+              private final SerializableFunction<Row, T> fromRowFunction =
+                  ((SchemaCoder) outputCoder).getFromRowFunction();
+
+              @Override
+              public void output(Row output) {
+                // Delegate to the context, as the main-output receiver does.
+                context.output(tag, fromRowFunction.apply(output));
+              }
+
+              @Override
+              public void outputWithTimestamp(Row output, Instant timestamp) {
+                context.outputWithTimestamp(tag, fromRowFunction.apply(output), timestamp);
+              }
+            };
+          }
+
+          @Override
+          public <T> OutputReceiver<T> get(TupleTag<T> tag) {
+            return (OutputReceiver<T>)
+                taggedOutputReceivers.computeIfAbsent(tag, this::createTaggedOutputReceiver);
+          }
+
+          @Override
+          public <T> OutputReceiver<Row> getRowReceiver(TupleTag<T> tag) {
+            return taggedRowReceivers.computeIfAbsent(tag, this::createTaggedRowReceiver);
+          }
+        };

     @Override
     public MultiOutputReceiver taggedOutputReceiver(DoFn<InputT, OutputT> doFn) {
-      return DoFnOutputReceivers.windowedMultiReceiver(context);
+      return taggedOutputReceiver;
     }

     @Override

Reply via email to