This is an automated email from the ASF dual-hosted git repository.

mmack pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a4cf26fdc99 [Spark Dataset runner] Fix initialization of metrics accumulator on driver (fixes #24809) (#24810)
a4cf26fdc99 is described below

commit a4cf26fdc99b312fb2968303f1bb06a14920d55a
Author: Moritz Mack <[email protected]>
AuthorDate: Tue Jan 3 13:36:09 2023 +0100

    [Spark Dataset runner] Fix initialization of metrics accumulator on driver (fixes #24809) (#24810)
---
 .../SparkStructuredStreamingPipelineResult.java    |   9 +-
 .../SparkStructuredStreamingRunner.java            |  37 ++++---
 .../metrics/MetricsAccumulator.java                | 117 +++++++++++++++------
 .../MetricsContainerStepMapAccumulator.java        |  65 ------------
 .../metrics/SparkBeamMetric.java                   |  12 ++-
 .../metrics/SparkBeamMetricSource.java             |   4 +-
 .../metrics/SparkMetricsContainerStepMap.java      |  43 --------
 .../batch/DoFnPartitionIteratorFactory.java        |  26 +++--
 .../translation/batch/DoFnRunnerWithMetrics.java   |  31 +++---
 .../translation/batch/ParDoTranslatorBatch.java    |   5 +-
 10 files changed, 158 insertions(+), 191 deletions(-)

diff --git 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java
 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java
index 0f636caa489..b490ff875c3 100644
--- 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java
+++ 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingPipelineResult.java
@@ -37,13 +37,16 @@ import org.joda.time.Duration;
 public class SparkStructuredStreamingPipelineResult implements PipelineResult {
 
   private final Future<?> pipelineExecution;
+  private final MetricsAccumulator metrics;
   private @Nullable final Runnable onTerminalState;
-
   private PipelineResult.State state;
 
   SparkStructuredStreamingPipelineResult(
-      Future<?> pipelineExecution, @Nullable Runnable onTerminalState) {
+      Future<?> pipelineExecution,
+      MetricsAccumulator metrics,
+      @Nullable final Runnable onTerminalState) {
     this.pipelineExecution = pipelineExecution;
+    this.metrics = metrics;
     this.onTerminalState = onTerminalState;
     // pipelineExecution is expected to have started executing eagerly.
     this.state = State.RUNNING;
@@ -105,7 +108,7 @@ public class SparkStructuredStreamingPipelineResult 
implements PipelineResult {
 
   @Override
   public MetricResults metrics() {
-    return 
asAttemptedOnlyMetricResults(MetricsAccumulator.getInstance().value());
+    return asAttemptedOnlyMetricResults(metrics.value());
   }
 
   @Override
diff --git 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
index bca1bdc2a2a..3b9f96cdb7e 100644
--- 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
+++ 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/SparkStructuredStreamingRunner.java
@@ -28,6 +28,7 @@ import javax.annotation.Nullable;
 import org.apache.beam.runners.core.construction.SplittableParDo;
 import 
org.apache.beam.runners.core.construction.graph.ProjectionPushdownOptimizer;
 import org.apache.beam.runners.core.metrics.MetricsPusher;
+import org.apache.beam.runners.core.metrics.NoOpMetricsSink;
 import 
org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
 import 
org.apache.beam.runners.spark.structuredstreaming.metrics.SparkBeamMetricSource;
 import 
org.apache.beam.runners.spark.structuredstreaming.translation.EvaluationContext;
@@ -44,7 +45,6 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.options.PipelineOptionsValidator;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ThreadFactoryBuilder;
-import org.apache.spark.SparkContext;
 import org.apache.spark.SparkEnv$;
 import org.apache.spark.metrics.MetricsSystem;
 import org.apache.spark.sql.SparkSession;
@@ -139,6 +139,7 @@ public final class SparkStructuredStreamingRunner
   @Override
   public SparkStructuredStreamingPipelineResult run(final Pipeline pipeline) {
     MetricsEnvironment.setMetricsSupported(true);
+    MetricsAccumulator.clear();
 
     LOG.info(
         "*** SparkStructuredStreamingRunner is based on spark structured 
streaming framework and is no more \n"
@@ -150,23 +151,21 @@ public final class SparkStructuredStreamingRunner
     checkArgument(!options.isStreaming(), "Streaming is not supported.");
 
     final SparkSession sparkSession = 
SparkSessionFactory.getOrCreateSession(options);
-    initMetrics(sparkSession.sparkContext());
+    final MetricsAccumulator metrics = 
MetricsAccumulator.getInstance(sparkSession);
 
     final Future<?> submissionFuture =
         runAsync(() -> translatePipeline(sparkSession, pipeline).evaluate());
 
     final SparkStructuredStreamingPipelineResult result =
         new SparkStructuredStreamingPipelineResult(
-            submissionFuture, stopSparkSession(sparkSession, 
options.getUseActiveSparkSession()));
+            submissionFuture,
+            metrics,
+            sparkStopFn(sparkSession, options.getUseActiveSparkSession()));
 
     if (options.getEnableSparkMetricSinks()) {
-      registerMetricsSource(options.getAppName());
+      registerMetricsSource(options.getAppName(), metrics);
     }
-
-    MetricsPusher metricsPusher =
-        new MetricsPusher(
-            MetricsAccumulator.getInstance().value(), 
options.as(MetricsOptions.class), result);
-    metricsPusher.start();
+    startMetricsPusher(result, metrics);
 
     if (options.getTestMode()) {
       result.waitUntilFinish();
@@ -195,19 +194,23 @@ public final class SparkStructuredStreamingRunner
     return pipelineTranslator.translate(pipeline, sparkSession, options);
   }
 
-  private void registerMetricsSource(String appName) {
+  private void registerMetricsSource(String appName, MetricsAccumulator 
metrics) {
     final MetricsSystem metricsSystem = 
SparkEnv$.MODULE$.get().metricsSystem();
-    final SparkBeamMetricSource metricsSource = new 
SparkBeamMetricSource(appName + ".Beam");
+    final SparkBeamMetricSource metricsSource =
+        new SparkBeamMetricSource(appName + ".Beam", metrics);
     // re-register the metrics in case of context re-use
     metricsSystem.removeSource(metricsSource);
     metricsSystem.registerSource(metricsSource);
   }
 
-  /** Init Metrics/Aggregators accumulators. This method is idempotent. */
-  private static void initMetrics(SparkContext sparkContext) {
-    // Clear and init metrics accumulators
-    MetricsAccumulator.clear();
-    MetricsAccumulator.init(sparkContext);
+  /** Start {@link MetricsPusher} if sink is set. */
+  private void startMetricsPusher(
+      SparkStructuredStreamingPipelineResult result, MetricsAccumulator 
metrics) {
+    MetricsOptions metricsOpts = options.as(MetricsOptions.class);
+    Class<?> metricsSink = metricsOpts.getMetricsSink();
+    if (metricsSink != null && !metricsSink.equals(NoOpMetricsSink.class)) {
+      new MetricsPusher(metrics.value(), metricsOpts, result).start();
+    }
   }
 
   private static Future<?> runAsync(Runnable task) {
@@ -222,7 +225,7 @@ public final class SparkStructuredStreamingRunner
     return future;
   }
 
-  private static @Nullable Runnable stopSparkSession(SparkSession session, 
boolean isProvided) {
+  private static @Nullable Runnable sparkStopFn(SparkSession session, boolean 
isProvided) {
     return !isProvided ? () -> session.stop() : null;
   }
 }
diff --git 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/MetricsAccumulator.java
 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/MetricsAccumulator.java
index a07ce967422..6edddff5831 100644
--- 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/MetricsAccumulator.java
+++ 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/MetricsAccumulator.java
@@ -19,51 +19,86 @@ package 
org.apache.beam.runners.spark.structuredstreaming.metrics;
 
 import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
-import org.apache.spark.SparkContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.util.AccumulatorV2;
+import org.checkerframework.checker.nullness.qual.Nullable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * For resilience, {@link AccumulatorV2 Accumulators} are required to be 
wrapped in a Singleton.
+ * {@link AccumulatorV2} for Beam metrics captured in {@link 
MetricsContainerStepMap}.
  *
  * @see <a
- *     
href="https://spark.apache.org/docs/2.4.4/streaming-programming-guide.html#accumulators-broadcast-variables-and-checkpoints";>accumulatorsV2</a>
+ *     
href="https://spark.apache.org/docs/latest/streaming-programming-guide.html#accumulators-broadcast-variables-and-checkpoints";>accumulatorsV2</a>
  */
-@SuppressWarnings({
-  "nullness" // TODO(https://github.com/apache/beam/issues/20497)
-})
-public class MetricsAccumulator {
+public class MetricsAccumulator
+    extends AccumulatorV2<MetricsContainerStepMap, MetricsContainerStepMap> {
   private static final Logger LOG = 
LoggerFactory.getLogger(MetricsAccumulator.class);
-
+  private static final MetricsContainerStepMap EMPTY = new 
SparkMetricsContainerStepMap();
   private static final String ACCUMULATOR_NAME = "Beam.Metrics";
 
-  private static volatile MetricsContainerStepMapAccumulator instance = null;
+  private static volatile @Nullable MetricsAccumulator instance = null;
 
-  /** Init metrics accumulator if it has not been initiated. This method is 
idempotent. */
-  public static void init(SparkContext sparkContext) {
-    if (instance == null) {
-      synchronized (MetricsAccumulator.class) {
-        if (instance == null) {
-          MetricsContainerStepMap metricsContainerStepMap = new 
SparkMetricsContainerStepMap();
-          MetricsContainerStepMapAccumulator accumulator =
-              new MetricsContainerStepMapAccumulator(metricsContainerStepMap);
-          sparkContext.register(accumulator, ACCUMULATOR_NAME);
+  private MetricsContainerStepMap value;
 
-          instance = accumulator;
-        }
-      }
-      LOG.info("Instantiated metrics accumulator: {}", instance.value());
-    } else {
-      instance.reset();
-    }
+  public MetricsAccumulator() {
+    value = new SparkMetricsContainerStepMap();
+  }
+
+  private MetricsAccumulator(MetricsContainerStepMap value) {
+    this.value = value;
+  }
+
+  @Override
+  public boolean isZero() {
+    return value.equals(EMPTY);
+  }
+
+  @Override
+  public MetricsAccumulator copy() {
+    MetricsContainerStepMap newContainer = new SparkMetricsContainerStepMap();
+    newContainer.updateAll(value);
+    return new MetricsAccumulator(newContainer);
+  }
+
+  @Override
+  public void reset() {
+    value = new SparkMetricsContainerStepMap();
+  }
+
+  @Override
+  public void add(MetricsContainerStepMap other) {
+    value.updateAll(other);
+  }
+
+  @Override
+  public void merge(AccumulatorV2<MetricsContainerStepMap, 
MetricsContainerStepMap> other) {
+    value.updateAll(other.value());
+  }
+
+  @Override
+  public MetricsContainerStepMap value() {
+    return value;
   }
 
-  public static MetricsContainerStepMapAccumulator getInstance() {
-    if (instance == null) {
-      throw new IllegalStateException("Metrics accumulator has not been 
instantiated");
-    } else {
-      return instance;
+  /**
+   * Get the {@link MetricsAccumulator} on this driver. If there's no such 
accumulator yet, it will
+   * be created and registered using the provided {@link SparkSession}.
+   */
+  public static MetricsAccumulator getInstance(SparkSession session) {
+    MetricsAccumulator current = instance;
+    if (current != null) {
+      return current;
+    }
+    synchronized (MetricsAccumulator.class) {
+      MetricsAccumulator accumulator = instance;
+      if (accumulator == null) {
+        accumulator = new MetricsAccumulator();
+        session.sparkContext().register(accumulator, ACCUMULATOR_NAME);
+        instance = accumulator;
+        LOG.info("Instantiated metrics accumulator: {}", instance.value());
+      }
+      return accumulator;
     }
   }
 
@@ -73,4 +108,26 @@ public class MetricsAccumulator {
       instance = null;
     }
   }
+
+  /**
+   * Sole purpose of this class is to override {@link #toString()} of {@link
+   * MetricsContainerStepMap} in order to show meaningful metrics in Spark Web 
Interface.
+   */
+  private static class SparkMetricsContainerStepMap extends 
MetricsContainerStepMap {
+
+    @Override
+    public String toString() {
+      return asAttemptedOnlyMetricResults(this).toString();
+    }
+
+    @Override
+    public boolean equals(@Nullable Object o) {
+      return super.equals(o);
+    }
+
+    @Override
+    public int hashCode() {
+      return super.hashCode();
+    }
+  }
 }
diff --git 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/MetricsContainerStepMapAccumulator.java
 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/MetricsContainerStepMapAccumulator.java
deleted file mode 100644
index 2d2a4ea1754..00000000000
--- 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/MetricsContainerStepMapAccumulator.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.metrics;
-
-import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
-import org.apache.spark.util.AccumulatorV2;
-
-/** {@link AccumulatorV2} implementation for {@link MetricsContainerStepMap}. 
*/
-public class MetricsContainerStepMapAccumulator
-    extends AccumulatorV2<MetricsContainerStepMap, MetricsContainerStepMap> {
-  private static final MetricsContainerStepMap empty = new 
SparkMetricsContainerStepMap();
-
-  private MetricsContainerStepMap value;
-
-  public MetricsContainerStepMapAccumulator(MetricsContainerStepMap value) {
-    this.value = value;
-  }
-
-  @Override
-  public boolean isZero() {
-    return value.equals(empty);
-  }
-
-  @Override
-  public MetricsContainerStepMapAccumulator copy() {
-    MetricsContainerStepMap newContainer = new SparkMetricsContainerStepMap();
-    newContainer.updateAll(value);
-    return new MetricsContainerStepMapAccumulator(newContainer);
-  }
-
-  @Override
-  public void reset() {
-    this.value = new SparkMetricsContainerStepMap();
-  }
-
-  @Override
-  public void add(MetricsContainerStepMap other) {
-    this.value.updateAll(other);
-  }
-
-  @Override
-  public void merge(AccumulatorV2<MetricsContainerStepMap, 
MetricsContainerStepMap> other) {
-    this.value.updateAll(other.value());
-  }
-
-  @Override
-  public MetricsContainerStepMap value() {
-    return this.value;
-  }
-}
diff --git 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetric.java
 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetric.java
index 0cecae4a25b..1754ac4d167 100644
--- 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetric.java
+++ 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetric.java
@@ -28,6 +28,7 @@ import java.util.Map;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import javax.annotation.Nullable;
+import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.sdk.metrics.DistributionResult;
 import org.apache.beam.sdk.metrics.GaugeResult;
 import org.apache.beam.sdk.metrics.MetricKey;
@@ -40,17 +41,22 @@ import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Streams;
 
 /**
- * An adapter between the {@link SparkMetricsContainerStepMap} and the 
Dropwizard {@link Metric}
+ * An adapter between the {@link MetricsContainerStepMap} and the Dropwizard 
{@link Metric}
  * interface.
  */
 class SparkBeamMetric extends BeamMetricSet {
 
   private static final String ILLEGAL_CHARACTERS = "[^A-Za-z0-9-]";
 
+  private final MetricsAccumulator metrics;
+
+  SparkBeamMetric(MetricsAccumulator metrics) {
+    this.metrics = metrics;
+  }
+
   @Override
   public Map<String, Gauge<Double>> getValue(String prefix, MetricFilter 
filter) {
-    MetricResults metricResults =
-        asAttemptedOnlyMetricResults(MetricsAccumulator.getInstance().value());
+    MetricResults metricResults = 
asAttemptedOnlyMetricResults(metrics.value());
     Map<String, Gauge<Double>> metrics = new HashMap<>();
     MetricQueryResults allMetrics = metricResults.allMetrics();
     for (MetricResult<Long> metricResult : allMetrics.getCounters()) {
diff --git 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricSource.java
 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricSource.java
index ed938ac8413..8a1e980ae0c 100644
--- 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricSource.java
+++ 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkBeamMetricSource.java
@@ -29,9 +29,9 @@ public class SparkBeamMetricSource implements Source {
 
   private final MetricRegistry metricRegistry = new MetricRegistry();
 
-  public SparkBeamMetricSource(final String name) {
+  public SparkBeamMetricSource(String name, MetricsAccumulator metrics) {
     this.name = name;
-    metricRegistry.register(name, new SparkBeamMetric());
+    metricRegistry.register(name, new SparkBeamMetric(metrics));
   }
 
   @Override
diff --git 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkMetricsContainerStepMap.java
 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkMetricsContainerStepMap.java
deleted file mode 100644
index 533dceb42e2..00000000000
--- 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/metrics/SparkMetricsContainerStepMap.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.beam.runners.spark.structuredstreaming.metrics;
-
-import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
-import org.checkerframework.checker.nullness.qual.Nullable;
-
-/**
- * Sole purpose of this class is to override {@link #toString()} of {@link 
MetricsContainerStepMap}
- * in order to show meaningful metrics in Spark Web Interface.
- */
-class SparkMetricsContainerStepMap extends MetricsContainerStepMap {
-
-  @Override
-  public String toString() {
-    return asAttemptedOnlyMetricResults(this).toString();
-  }
-
-  @Override
-  public boolean equals(@Nullable Object o) {
-    return super.equals(o);
-  }
-
-  @Override
-  public int hashCode() {
-    return super.hashCode();
-  }
-}
diff --git 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java
 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java
index 64a4f591ff7..df844cc9f11 100644
--- 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java
+++ 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnPartitionIteratorFactory.java
@@ -72,11 +72,14 @@ abstract class DoFnPartitionIteratorFactory<InT, FnOutT, 
OutT extends @NonNull O
   protected final Map<String, PCollectionView<?>> sideInputs;
   protected final SideInputReader sideInputReader;
 
+  private final MetricsAccumulator metrics;
+
   private DoFnPartitionIteratorFactory(
       AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, 
FnOutT>> appliedPT,
       Supplier<PipelineOptions> options,
       PCollection<InT> input,
-      SideInputReader sideInputReader) {
+      SideInputReader sideInputReader,
+      MetricsAccumulator metrics) {
     this.stepName = appliedPT.getFullName();
     this.doFn = appliedPT.getTransform().getFn();
     this.doFnSchema = ParDoTranslation.getSchemaInformation(appliedPT);
@@ -88,6 +91,7 @@ abstract class DoFnPartitionIteratorFactory<InT, FnOutT, OutT 
extends @NonNull O
     this.outputCoders = outputCoders(appliedPT.getOutputs());
     this.sideInputs = appliedPT.getTransform().getSideInputs();
     this.sideInputReader = sideInputReader;
+    this.metrics = metrics;
   }
 
   /**
@@ -98,8 +102,9 @@ abstract class DoFnPartitionIteratorFactory<InT, FnOutT, 
OutT extends @NonNull O
       AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, OutT>> 
appliedPT,
       Supplier<PipelineOptions> options,
       PCollection<InT> input,
-      SideInputReader sideInputReader) {
-    return new SingleOut<>(appliedPT, options, input, sideInputReader);
+      SideInputReader sideInputReader,
+      MetricsAccumulator metrics) {
+    return new SingleOut<>(appliedPT, options, input, sideInputReader, 
metrics);
   }
 
   /**
@@ -113,8 +118,9 @@ abstract class DoFnPartitionIteratorFactory<InT, FnOutT, 
OutT extends @NonNull O
           Supplier<PipelineOptions> options,
           PCollection<InT> input,
           SideInputReader sideInputReader,
+          MetricsAccumulator metrics,
           Map<String, Integer> tagColIdx) {
-    return new MultiOut<>(appliedPT, options, input, sideInputReader, 
tagColIdx);
+    return new MultiOut<>(appliedPT, options, input, sideInputReader, metrics, 
tagColIdx);
   }
 
   @Override
@@ -137,8 +143,9 @@ abstract class DoFnPartitionIteratorFactory<InT, FnOutT, 
OutT extends @NonNull O
         AppliedPTransform<PCollection<? extends InT>, ?, MultiOutput<InT, 
OutT>> appliedPT,
         Supplier<PipelineOptions> options,
         PCollection<InT> input,
-        SideInputReader sideInputReader) {
-      super(appliedPT, options, input, sideInputReader);
+        SideInputReader sideInputReader,
+        MetricsAccumulator metrics) {
+      super(appliedPT, options, input, sideInputReader, metrics);
     }
 
     @Override
@@ -170,8 +177,9 @@ abstract class DoFnPartitionIteratorFactory<InT, FnOutT, 
OutT extends @NonNull O
         Supplier<PipelineOptions> options,
         PCollection<InT> input,
         SideInputReader sideInputReader,
+        MetricsAccumulator metrics,
         Map<String, Integer> tagColIdx) {
-      super(appliedPT, options, input, sideInputReader);
+      super(appliedPT, options, input, sideInputReader, metrics);
       this.tagColIdx = tagColIdx;
     }
 
@@ -246,7 +254,7 @@ abstract class DoFnPartitionIteratorFactory<InT, FnOutT, 
OutT extends @NonNull O
   private DoFnRunner<InT, FnOutT> simpleRunner(Deque<OutT> buffer) {
     return DoFnRunners.simpleRunner(
         options.get(),
-        (DoFn<InT, FnOutT>) doFn,
+        doFn,
         CachedSideInputReader.of(sideInputReader, sideInputs.values()),
         outputManager(buffer),
         mainOutput,
@@ -260,7 +268,7 @@ abstract class DoFnPartitionIteratorFactory<InT, FnOutT, 
OutT extends @NonNull O
   }
 
   private DoFnRunner<InT, FnOutT> metricsRunner(DoFnRunner<InT, FnOutT> 
runner) {
-    return new DoFnRunnerWithMetrics<>(stepName, runner, 
MetricsAccumulator.getInstance());
+    return new DoFnRunnerWithMetrics<>(stepName, runner, metrics);
   }
 
   private static Map<TupleTag<?>, Coder<?>> outputCoders(Map<TupleTag<?>, 
PCollection<?>> outputs) {
diff --git 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerWithMetrics.java
 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerWithMetrics.java
index b80ec87d3c5..f6b98a61e3d 100644
--- 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerWithMetrics.java
+++ 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/DoFnRunnerWithMetrics.java
@@ -20,8 +20,7 @@ package 
org.apache.beam.runners.spark.structuredstreaming.translation.batch;
 import java.io.Closeable;
 import java.io.IOException;
 import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.metrics.MetricsContainerImpl;
-import 
org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsContainerStepMapAccumulator;
+import 
org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
 import org.apache.beam.sdk.metrics.MetricsContainer;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.state.TimeDomain;
@@ -30,19 +29,19 @@ import 
org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.joda.time.Instant;
 
-/** DoFnRunner decorator which registers {@link MetricsContainerImpl}. */
+/** DoFnRunner decorator which registers {@link MetricsContainer}. */
 class DoFnRunnerWithMetrics<InputT, OutputT> implements DoFnRunner<InputT, 
OutputT> {
   private final DoFnRunner<InputT, OutputT> delegate;
-  private final String stepName;
-  private final MetricsContainerStepMapAccumulator metricsAccum;
+  private final MetricsContainer metrics;
 
   DoFnRunnerWithMetrics(
-      String stepName,
-      DoFnRunner<InputT, OutputT> delegate,
-      MetricsContainerStepMapAccumulator metricsAccum) {
+      String stepName, DoFnRunner<InputT, OutputT> delegate, 
MetricsAccumulator metricsAccum) {
+    this(delegate, metricsAccum.value().getContainer(stepName));
+  }
+
+  private DoFnRunnerWithMetrics(DoFnRunner<InputT, OutputT> delegate, 
MetricsContainer metrics) {
     this.delegate = delegate;
-    this.stepName = stepName;
-    this.metricsAccum = metricsAccum;
+    this.metrics = metrics;
   }
 
   @Override
@@ -52,7 +51,7 @@ class DoFnRunnerWithMetrics<InputT, OutputT> implements 
DoFnRunner<InputT, Outpu
 
   @Override
   public void startBundle() {
-    try (Closeable ignored = 
MetricsEnvironment.scopedMetricsContainer(metricsContainer())) {
+    try (Closeable ignored = 
MetricsEnvironment.scopedMetricsContainer(metrics)) {
       delegate.startBundle();
     } catch (IOException e) {
       throw new RuntimeException(e);
@@ -61,7 +60,7 @@ class DoFnRunnerWithMetrics<InputT, OutputT> implements 
DoFnRunner<InputT, Outpu
 
   @Override
   public void processElement(final WindowedValue<InputT> elem) {
-    try (Closeable ignored = 
MetricsEnvironment.scopedMetricsContainer(metricsContainer())) {
+    try (Closeable ignored = 
MetricsEnvironment.scopedMetricsContainer(metrics)) {
       delegate.processElement(elem);
     } catch (IOException e) {
       throw new RuntimeException(e);
@@ -77,7 +76,7 @@ class DoFnRunnerWithMetrics<InputT, OutputT> implements 
DoFnRunner<InputT, Outpu
       final Instant timestamp,
       final Instant outputTimestamp,
       final TimeDomain timeDomain) {
-    try (Closeable ignored = 
MetricsEnvironment.scopedMetricsContainer(metricsContainer())) {
+    try (Closeable ignored = 
MetricsEnvironment.scopedMetricsContainer(metrics)) {
       delegate.onTimer(timerId, timerFamilyId, key, window, timestamp, 
outputTimestamp, timeDomain);
     } catch (IOException e) {
       throw new RuntimeException(e);
@@ -86,7 +85,7 @@ class DoFnRunnerWithMetrics<InputT, OutputT> implements 
DoFnRunner<InputT, Outpu
 
   @Override
   public void finishBundle() {
-    try (Closeable ignored = 
MetricsEnvironment.scopedMetricsContainer(metricsContainer())) {
+    try (Closeable ignored = 
MetricsEnvironment.scopedMetricsContainer(metrics)) {
       delegate.finishBundle();
     } catch (IOException e) {
       throw new RuntimeException(e);
@@ -97,8 +96,4 @@ class DoFnRunnerWithMetrics<InputT, OutputT> implements 
DoFnRunner<InputT, Outpu
   public <KeyT> void onWindowExpiration(BoundedWindow window, Instant 
timestamp, KeyT key) {
     delegate.onWindowExpiration(window, timestamp, key);
   }
-
-  private MetricsContainer metricsContainer() {
-    return metricsAccum.value().getContainer(stepName);
-  }
 }
diff --git 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
index 4d545e43813..44252237930 100644
--- 
a/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
+++ 
b/runners/spark/3/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/batch/ParDoTranslatorBatch.java
@@ -34,6 +34,7 @@ import java.util.Map.Entry;
 import org.apache.beam.runners.core.DoFnRunners;
 import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.runners.spark.SparkCommonPipelineOptions;
+import 
org.apache.beam.runners.spark.structuredstreaming.metrics.MetricsAccumulator;
 import 
org.apache.beam.runners.spark.structuredstreaming.translation.TransformTranslator;
 import 
org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SideInputValues;
 import 
org.apache.beam.runners.spark.structuredstreaming.translation.batch.functions.SparkSideInputReader;
@@ -114,6 +115,7 @@ class ParDoTranslatorBatch<InputT, OutputT>
     Dataset<WindowedValue<InputT>> inputDs = cxt.getDataset(input);
     SideInputReader sideInputReader =
         createSideInputReader(transform.getSideInputs().values(), cxt);
+    MetricsAccumulator metrics = 
MetricsAccumulator.getInstance(cxt.getSparkSession());
 
     TupleTag<OutputT> mainOut = transform.getMainOutputTag();
     // Filter out unconsumed PCollections (except mainOut) to potentially 
avoid the costs of caching
@@ -135,6 +137,7 @@ class ParDoTranslatorBatch<InputT, OutputT>
               cxt.getOptionsSupplier(),
               input,
               sideInputReader,
+              metrics,
               tagColIdx);
 
       // FIXME What's the strategy to unpersist Datasets / RDDs?
@@ -186,7 +189,7 @@ class ParDoTranslatorBatch<InputT, OutputT>
       PCollection<OutputT> output = cxt.getOutput(mainOut);
       DoFnPartitionIteratorFactory<InputT, ?, WindowedValue<OutputT>> 
doFnMapper =
           DoFnPartitionIteratorFactory.singleOutput(
-              cxt.getCurrentTransform(), cxt.getOptionsSupplier(), input, 
sideInputReader);
+              cxt.getCurrentTransform(), cxt.getOptionsSupplier(), input, 
sideInputReader, metrics);
 
       Dataset<WindowedValue<OutputT>> mainDS =
           inputDs.mapPartitions(doFnMapper, 
cxt.windowedEncoder(output.getCoder()));

Reply via email to