This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 679d30256c6 Swap setting a context from being on the hot path when we 
emit elements to only be done during bundle creation and teardown #21250 
(#25291)
679d30256c6 is described below

commit 679d30256c6bd64d9760702c667d7d355e70166b
Author: Luke Cwik <lc...@google.com>
AuthorDate: Fri Feb 3 15:39:54 2023 -0800

    Swap setting a context from being on the hot path when we emit elements to 
only be done during bundle creation and teardown #21250 (#25291)
    
    * Swap setting a context from being on the hot path when we emit elements 
to only be done during bundle creation and teardown.
    
    This significantly cuts our per-element processing overhead. For 
example, in ProcessBundleBenchmark.testLargeBundle the execution state 
management overhead for each element dropped by about half (from 10.6% to 5.6% 
CPU overhead per element).
    
    The benchmark reflects this with a nice 11.5% performance improvement.
    
    Before:
    ```
    Result 
"org.apache.beam.fn.harness.jmh.ProcessBundleBenchmark.testLargeBundle":
      3265.141 ±(99.9%) 238.744 ops/s [Average]
      (min, avg, max) = (2952.616, 3265.141, 3525.482), stdev = 223.321
      CI (99.9%): [3026.398, 3503.885] (assumes normal distribution)
    ```
    
    After:
    ```
    Result 
"org.apache.beam.fn.harness.jmh.ProcessBundleBenchmark.testLargeBundle":
      3642.512 ±(99.9%) 45.865 ops/s [Average]
      (min, avg, max) = (3582.394, 3642.512, 3713.820), stdev = 42.902
      CI (99.9%): [3596.647, 3688.377] (assumes normal distribution)
    ```
---
 .../fn/harness/control/ExecutionStateSampler.java  | 109 +++++++++++++++-
 .../fn/harness/control/ProcessBundleHandler.java   |  57 +++-----
 .../harness/data/PCollectionConsumerRegistry.java  |  74 +++--------
 .../harness/data/PTransformFunctionRegistry.java   |  26 +---
 .../harness/control/ExecutionStateSamplerTest.java | 144 ++++++++++++++++++++
 .../harness/control/ProcessBundleHandlerTest.java  |  11 +-
 .../data/PCollectionConsumerRegistryTest.java      | 145 +++++++--------------
 .../data/PTransformFunctionRegistryTest.java       |  81 ++++++------
 8 files changed, 382 insertions(+), 265 deletions(-)

diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java
index f528a4a9919..9a94ce05057 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java
@@ -32,10 +32,19 @@ import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 import javax.annotation.concurrent.GuardedBy;
 import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor;
+import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
+import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoEncodings;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.Distribution;
+import org.apache.beam.sdk.metrics.Gauge;
+import org.apache.beam.sdk.metrics.Histogram;
+import org.apache.beam.sdk.metrics.MetricName;
+import org.apache.beam.sdk.metrics.MetricsContainer;
 import org.apache.beam.sdk.options.ExecutorOptions;
 import org.apache.beam.sdk.options.ExperimentalOptions;
 import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.util.HistogramData;
 import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
 import org.checkerframework.checker.nullness.qual.Nullable;
@@ -163,9 +172,66 @@ public class ExecutionStateSampler {
     return new ExecutionStateTracker();
   }
 
+  /**
+   * A {@link MetricsContainer} that uses the current {@link ExecutionState} 
tracked by the provided
+   * {@link ExecutionStateTracker}.
+   */
+  private static class MetricsContainerForTracker implements MetricsContainer {
+    private final transient ExecutionStateTracker tracker;
+
+    private MetricsContainerForTracker(ExecutionStateTracker tracker) {
+      this.tracker = tracker;
+    }
+
+    @Override
+    public Counter getCounter(MetricName metricName) {
+      if (tracker.currentState != null) {
+        return tracker.currentState.metricsContainer.getCounter(metricName);
+      }
+      return 
tracker.metricsContainerRegistry.getUnboundContainer().getCounter(metricName);
+    }
+
+    @Override
+    public Distribution getDistribution(MetricName metricName) {
+      if (tracker.currentState != null) {
+        return 
tracker.currentState.metricsContainer.getDistribution(metricName);
+      }
+      return 
tracker.metricsContainerRegistry.getUnboundContainer().getDistribution(metricName);
+    }
+
+    @Override
+    public Gauge getGauge(MetricName metricName) {
+      if (tracker.currentState != null) {
+        return tracker.currentState.metricsContainer.getGauge(metricName);
+      }
+      return 
tracker.metricsContainerRegistry.getUnboundContainer().getGauge(metricName);
+    }
+
+    @Override
+    public Histogram getHistogram(MetricName metricName, 
HistogramData.BucketType bucketType) {
+      if (tracker.currentState != null) {
+        return tracker.currentState.metricsContainer.getHistogram(metricName, 
bucketType);
+      }
+      return tracker
+          .metricsContainerRegistry
+          .getUnboundContainer()
+          .getHistogram(metricName, bucketType);
+    }
+
+    @Override
+    public Iterable<MonitoringInfo> getMonitoringInfos() {
+      if (tracker.currentState != null) {
+        return tracker.currentState.metricsContainer.getMonitoringInfos();
+      }
+      return 
tracker.metricsContainerRegistry.getUnboundContainer().getMonitoringInfos();
+    }
+  }
+
   /** Tracks the current state of a single execution thread. */
   public class ExecutionStateTracker implements BundleProgressReporter {
-
+    // Used to create and store metrics containers for the execution states.
+    private final MetricsContainerStepMap metricsContainerRegistry;
+    private final MetricsContainer metricsContainer;
     // The set of execution states that this tracker is responsible for. 
Effectively
     // final since create() should not be invoked once any bundle starts 
processing.
     private final List<ExecutionStateImpl> executionStates;
@@ -187,13 +253,36 @@ public class ExecutionStateSampler {
     // Read and written by the ExecutionStateSampler thread
     private long transitionsAtLastSample;
 
+    // Ignore the @UnderInitialization for ExecutionStateTracker since it will 
be initialized by the
+    // time this method returns and no references are leaked to other threads 
during construction.
+    @SuppressWarnings({"assignment", "argument"})
     private ExecutionStateTracker() {
+      this.metricsContainerRegistry = new MetricsContainerStepMap();
       this.executionStates = new ArrayList<>();
       this.trackedThread = new AtomicReference<>();
       this.lastTransitionTime = new AtomicLong();
       this.numTransitionsLazy = new AtomicLong();
       this.currentStateLazy = new AtomicReference<>();
       this.processBundleId = new AtomicReference<>();
+      this.metricsContainer = new MetricsContainerForTracker(this);
+    }
+
+    /**
+     * Returns the {@link MetricsContainerStepMap} that is managed by this 
{@link
+     * ExecutionStateTracker}. This metrics container registry stores all the 
user counters
+     * associated for the current bundle execution.
+     */
+    public MetricsContainerStepMap getMetricsContainerRegistry() {
+      return metricsContainerRegistry;
+    }
+
+    /**
+     * Returns a {@link MetricsContainer} that delegates based upon the 
current execution state to
+     * the appropriate metrics container that is bound to the current {@link 
ExecutionState} or to
+     * the unbound {@link MetricsContainer} if no execution state is currently 
running.
+     */
+    public MetricsContainer getMetricsContainer() {
+      return metricsContainer;
     }
 
     /**
@@ -203,7 +292,12 @@ public class ExecutionStateSampler {
     public ExecutionState create(
         String shortId, String ptransformId, String ptransformUniqueName, 
String stateName) {
       ExecutionStateImpl newState =
-          new ExecutionStateImpl(shortId, ptransformId, ptransformUniqueName, 
stateName);
+          new ExecutionStateImpl(
+              shortId,
+              ptransformId,
+              ptransformUniqueName,
+              stateName,
+              metricsContainerRegistry.getContainer(ptransformId));
       executionStates.add(newState);
       return newState;
     }
@@ -285,10 +379,13 @@ public class ExecutionStateSampler {
 
     /** {@link ExecutionState} represents the current state of an execution 
thread. */
     private class ExecutionStateImpl implements ExecutionState {
+
       private final String shortId;
       private final String ptransformId;
       private final String ptransformUniqueName;
       private final String stateName;
+      private final MetricsContainer metricsContainer;
+
       // Read and written by the bundle processing thread frequently.
       private long msecs;
       // Read by the ExecutionStateSampler, written by the bundle processing 
thread frequently.
@@ -301,11 +398,16 @@ public class ExecutionStateSampler {
       private @Nullable ExecutionStateImpl previousState;
 
       private ExecutionStateImpl(
-          String shortId, String ptransformId, String ptransformName, String 
stateName) {
+          String shortId,
+          String ptransformId,
+          String ptransformName,
+          String stateName,
+          MetricsContainer metricsContainer) {
         this.shortId = shortId;
         this.ptransformId = ptransformId;
         this.ptransformUniqueName = ptransformName;
         this.stateName = stateName;
+        this.metricsContainer = metricsContainer;
         this.lazyMsecs = new AtomicLong();
       }
 
@@ -406,6 +508,7 @@ public class ExecutionStateSampler {
       this.numTransitions = 0;
       this.numTransitionsLazy.lazySet(0);
       this.lastTransitionTime.lazySet(0);
+      this.metricsContainerRegistry.reset();
     }
   }
 
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
index f2b95450391..560369a3907 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java
@@ -74,7 +74,6 @@ import 
org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy;
 import org.apache.beam.runners.core.construction.BeamUrns;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.Timer;
-import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
 import org.apache.beam.runners.core.metrics.ShortIdMap;
 import org.apache.beam.sdk.fn.data.BeamFnDataInboundObserver;
@@ -517,7 +516,6 @@ public class ProcessBundleHandler {
       ExecutionStateTracker stateTracker = bundleProcessor.getStateTracker();
 
       try (HandleStateCallsForBundle beamFnStateClient = 
bundleProcessor.getBeamFnStateClient()) {
-        bundleProcessor.getMetricsEnvironmentStateForBundle().start();
         stateTracker.start(request.getInstructionId());
         try {
           // Already in reverse topological order so we don't need to do 
anything.
@@ -688,7 +686,10 @@ public class ProcessBundleHandler {
     Map<String, ByteString> monitoringData = new HashMap<>();
     // Extract MonitoringInfos that come from the metrics container registry.
     monitoringData.putAll(
-        
bundleProcessor.getMetricsContainerRegistry().getMonitoringData(shortIds));
+        bundleProcessor
+            .getStateTracker()
+            .getMetricsContainerRegistry()
+            .getMonitoringData(shortIds));
     // Add any additional monitoring infos that the "runners" report 
explicitly.
     bundleProcessor
         .getBundleProgressReporterAndRegistrar()
@@ -701,7 +702,10 @@ public class ProcessBundleHandler {
     HashMap<String, ByteString> monitoringData = new HashMap<>();
     // Extract MonitoringInfos that come from the metrics container registry.
     monitoringData.putAll(
-        
bundleProcessor.getMetricsContainerRegistry().getMonitoringData(shortIds));
+        bundleProcessor
+            .getStateTracker()
+            .getMetricsContainerRegistry()
+            .getMonitoringData(shortIds));
     // Add any additional monitoring infos that the "runners" report 
explicitly.
     bundleProcessor
         .getBundleProgressReporterAndRegistrar()
@@ -736,21 +740,22 @@ public class ProcessBundleHandler {
   }
 
   @VisibleForTesting
-  static class MetricsEnvironmentStateForBundle implements 
MetricsEnvironmentState {
+  static class MetricsEnvironmentStateForBundle {
     private @Nullable MetricsEnvironmentState currentThreadState;
 
-    @Override
-    public @Nullable MetricsContainer activate(@Nullable MetricsContainer 
metricsContainer) {
-      return currentThreadState.activate(metricsContainer);
-    }
-
-    public void start() {
+    public void start(MetricsContainer container) {
       currentThreadState = 
MetricsEnvironment.getMetricsEnvironmentStateForCurrentThread();
+      currentThreadState.activate(container);
     }
 
     public void reset() {
+      currentThreadState.activate(null);
       currentThreadState = null;
     }
+
+    public void discard() {
+      currentThreadState.activate(null);
+    }
   }
 
   private BundleProcessor createBundleProcessor(
@@ -760,35 +765,19 @@ public class ProcessBundleHandler {
     SetMultimap<String, String> pCollectionIdsToConsumingPTransforms = 
HashMultimap.create();
     BundleProgressReporter.InMemory bundleProgressReporterAndRegistrar =
         new BundleProgressReporter.InMemory();
-    MetricsContainerStepMap metricsContainerRegistry = new 
MetricsContainerStepMap();
     MetricsEnvironmentStateForBundle metricsEnvironmentStateForBundle =
         new MetricsEnvironmentStateForBundle();
     ExecutionStateTracker stateTracker = executionStateSampler.create();
     bundleProgressReporterAndRegistrar.register(stateTracker);
     PCollectionConsumerRegistry pCollectionConsumerRegistry =
         new PCollectionConsumerRegistry(
-            metricsContainerRegistry,
-            metricsEnvironmentStateForBundle,
-            stateTracker,
-            shortIds,
-            bundleProgressReporterAndRegistrar,
-            bundleDescriptor);
+            stateTracker, shortIds, bundleProgressReporterAndRegistrar, 
bundleDescriptor);
     HashSet<String> processedPTransformIds = new HashSet<>();
 
     PTransformFunctionRegistry startFunctionRegistry =
-        new PTransformFunctionRegistry(
-            metricsContainerRegistry,
-            metricsEnvironmentStateForBundle,
-            shortIds,
-            stateTracker,
-            Urns.START_BUNDLE_MSECS);
+        new PTransformFunctionRegistry(shortIds, stateTracker, 
Urns.START_BUNDLE_MSECS);
     PTransformFunctionRegistry finishFunctionRegistry =
-        new PTransformFunctionRegistry(
-            metricsContainerRegistry,
-            metricsEnvironmentStateForBundle,
-            shortIds,
-            stateTracker,
-            Urns.FINISH_BUNDLE_MSECS);
+        new PTransformFunctionRegistry(shortIds, stateTracker, 
Urns.FINISH_BUNDLE_MSECS);
     List<ThrowingRunnable> resetFunctions = new ArrayList<>();
     List<ThrowingRunnable> tearDownFunctions = new ArrayList<>();
 
@@ -835,7 +824,6 @@ public class ProcessBundleHandler {
             tearDownFunctions,
             splitListener,
             pCollectionConsumerRegistry,
-            metricsContainerRegistry,
             metricsEnvironmentStateForBundle,
             stateTracker,
             beamFnStateClient,
@@ -1028,7 +1016,6 @@ public class ProcessBundleHandler {
         List<ThrowingRunnable> tearDownFunctions,
         BundleSplitListener.InMemory splitListener,
         PCollectionConsumerRegistry pCollectionConsumerRegistry,
-        MetricsContainerStepMap metricsContainerRegistry,
         MetricsEnvironmentStateForBundle metricsEnvironmentStateForBundle,
         ExecutionStateTracker stateTracker,
         HandleStateCallsForBundle beamFnStateClient,
@@ -1044,7 +1031,6 @@ public class ProcessBundleHandler {
           tearDownFunctions,
           splitListener,
           pCollectionConsumerRegistry,
-          metricsContainerRegistry,
           metricsEnvironmentStateForBundle,
           stateTracker,
           beamFnStateClient,
@@ -1081,8 +1067,6 @@ public class ProcessBundleHandler {
 
     abstract PCollectionConsumerRegistry getpCollectionConsumerRegistry();
 
-    abstract MetricsContainerStepMap getMetricsContainerRegistry();
-
     abstract MetricsEnvironmentStateForBundle 
getMetricsEnvironmentStateForBundle();
 
     public abstract ExecutionStateTracker getStateTracker();
@@ -1140,6 +1124,7 @@ public class ProcessBundleHandler {
     synchronized void setupForProcessBundleRequest(InstructionRequest request) 
{
       this.instructionId = request.getInstructionId();
       this.cacheTokens = request.getProcessBundle().getCacheTokensList();
+      
getMetricsEnvironmentStateForBundle().start(getStateTracker().getMetricsContainer());
     }
 
     void reset() throws Exception {
@@ -1152,7 +1137,6 @@ public class ProcessBundleHandler {
         }
       }
       getSplitListener().clear();
-      getMetricsContainerRegistry().reset();
       getMetricsEnvironmentStateForBundle().reset();
       getStateTracker().reset();
       getBundleFinalizationCallbackRegistrations().clear();
@@ -1171,6 +1155,7 @@ public class ProcessBundleHandler {
         if (this.bundleCache != null) {
           this.bundleCache.clear();
         }
+        getMetricsEnvironmentStateForBundle().discard();
         for (BeamFnDataOutboundAggregator aggregator : 
getOutboundAggregators().values()) {
           aggregator.discard();
         }
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
index 556076ed3b1..45298a68d98 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java
@@ -35,7 +35,6 @@ import 
org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
 import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.runners.core.construction.RehydratedComponents;
-import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Labels;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns;
@@ -46,8 +45,6 @@ import 
org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.metrics.Distribution;
-import org.apache.beam.sdk.metrics.MetricsContainer;
-import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
 import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
@@ -71,12 +68,9 @@ public class PCollectionConsumerRegistry {
   @SuppressWarnings({"rawtypes"})
   abstract static class ConsumerAndMetadata {
     public static ConsumerAndMetadata forConsumer(
-        FnDataReceiver consumer,
-        String pTransformId,
-        ExecutionState state,
-        MetricsContainer metricsContainer) {
+        FnDataReceiver consumer, String pTransformId, ExecutionState state) {
       return new AutoValue_PCollectionConsumerRegistry_ConsumerAndMetadata(
-          consumer, pTransformId, state, metricsContainer);
+          consumer, pTransformId, state);
     }
 
     public abstract FnDataReceiver getConsumer();
@@ -84,12 +78,8 @@ public class PCollectionConsumerRegistry {
     public abstract String getPTransformId();
 
     public abstract ExecutionState getExecutionState();
-
-    public abstract MetricsContainer getMetricsContainer();
   }
 
-  private final MetricsContainerStepMap metricsContainerRegistry;
-  private final MetricsEnvironmentState metricsEnvironmentState;
   private final ExecutionStateTracker stateTracker;
   private final ShortIdMap shortIdMap;
   private final Map<String, List<ConsumerAndMetadata>> 
pCollectionIdsToConsumers;
@@ -99,14 +89,10 @@ public class PCollectionConsumerRegistry {
   private final RehydratedComponents rehydratedComponents;
 
   public PCollectionConsumerRegistry(
-      MetricsContainerStepMap metricsContainerRegistry,
-      MetricsEnvironmentState metricsEnvironmentState,
       ExecutionStateTracker stateTracker,
       ShortIdMap shortIdMap,
       BundleProgressReporter.Registrar bundleProgressReporterRegistrar,
       ProcessBundleDescriptor processBundleDescriptor) {
-    this.metricsContainerRegistry = metricsContainerRegistry;
-    this.metricsEnvironmentState = metricsEnvironmentState;
     this.stateTracker = stateTracker;
     this.shortIdMap = shortIdMap;
     this.pCollectionIdsToConsumers = new HashMap<>();
@@ -176,11 +162,7 @@ public class PCollectionConsumerRegistry {
     List<ConsumerAndMetadata> consumerAndMetadatas =
         pCollectionIdsToConsumers.computeIfAbsent(pCollectionId, (unused) -> 
new ArrayList<>());
     consumerAndMetadatas.add(
-        ConsumerAndMetadata.forConsumer(
-            consumer,
-            pTransformId,
-            executionState,
-            metricsContainerRegistry.getContainer(pTransformId)));
+        ConsumerAndMetadata.forConsumer(consumer, pTransformId, 
executionState));
   }
 
   /**
@@ -219,18 +201,16 @@ public class PCollectionConsumerRegistry {
           if (consumerAndMetadatas.size() == 1) {
             ConsumerAndMetadata consumerAndMetadata = 
consumerAndMetadatas.get(0);
             if (consumerAndMetadata.getConsumer() instanceof HandlesSplits) {
-              return new SplittingMetricTrackingFnDataReceiver(
-                  pcId, coder, consumerAndMetadata, metricsEnvironmentState);
+              return new SplittingMetricTrackingFnDataReceiver(pcId, coder, 
consumerAndMetadata);
             }
-            return new MetricTrackingFnDataReceiver(
-                pcId, coder, consumerAndMetadata, metricsEnvironmentState);
+            return new MetricTrackingFnDataReceiver(pcId, coder, 
consumerAndMetadata);
           } else {
             /* TODO(SDF), Consider supporting splitting each consumer 
individually. This would never
             come up in the existing SDF expansion, but might be useful to 
support fused SDF nodes.
             This would require dedicated delivery of the split results to each 
of the consumers
             separately. */
             return new MultiplexingMetricTrackingFnDataReceiver(
-                pcId, coder, ImmutableList.copyOf(consumerAndMetadatas), 
metricsEnvironmentState);
+                pcId, coder, ImmutableList.copyOf(consumerAndMetadatas));
           }
         });
   }
@@ -248,14 +228,9 @@ public class PCollectionConsumerRegistry {
     private final BundleCounter elementCountCounter;
     private final SampleByteSizeDistribution<T> sampledByteSizeDistribution;
     private final Coder<T> coder;
-    private final MetricsContainer metricsContainer;
-    private final MetricsEnvironmentState metricsEnvironmentState;
 
     public MetricTrackingFnDataReceiver(
-        String pCollectionId,
-        Coder<T> coder,
-        ConsumerAndMetadata consumerAndMetadata,
-        MetricsEnvironmentState metricsEnvironmentState) {
+        String pCollectionId, Coder<T> coder, ConsumerAndMetadata 
consumerAndMetadata) {
       this.delegate = consumerAndMetadata.getConsumer();
       this.executionState = consumerAndMetadata.getExecutionState();
 
@@ -291,8 +266,6 @@ public class PCollectionConsumerRegistry {
       
bundleProgressReporterRegistrar.register(sampledByteSizeUnderlyingDistribution);
 
       this.coder = coder;
-      this.metricsContainer = consumerAndMetadata.getMetricsContainer();
-      this.metricsEnvironmentState = metricsEnvironmentState;
     }
 
     @Override
@@ -303,17 +276,14 @@ public class PCollectionConsumerRegistry {
       // we have window optimization.
       this.sampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);
 
-      // Wrap the consumer with extra logic to set the metric container with 
the appropriate
-      // PTransform context. This ensures that user metrics obtain the 
pTransform ID when they are
-      // created. Also use the ExecutionStateTracker and enter an appropriate 
state to track the
-      // Process Bundle Execution time metric.
-      MetricsContainer oldContainer = 
metricsEnvironmentState.activate(metricsContainer);
+      // Use the ExecutionStateTracker and enter an appropriate state to track 
the
+      // Process Bundle Execution time metric and also ensure user counters 
can get an appropriate
+      // metrics container.
       executionState.activate();
       try {
         this.delegate.accept(input);
       } finally {
         executionState.deactivate();
-        metricsEnvironmentState.activate(oldContainer);
       }
       this.sampledByteSizeDistribution.finishLazyUpdate();
     }
@@ -329,18 +299,13 @@ public class PCollectionConsumerRegistry {
   private class MultiplexingMetricTrackingFnDataReceiver<T>
       implements FnDataReceiver<WindowedValue<T>> {
     private final List<ConsumerAndMetadata> consumerAndMetadatas;
-    private final MetricsEnvironmentState metricsEnvironmentState;
     private final BundleCounter elementCountCounter;
     private final SampleByteSizeDistribution<T> sampledByteSizeDistribution;
     private final Coder<T> coder;
 
     public MultiplexingMetricTrackingFnDataReceiver(
-        String pCollectionId,
-        Coder<T> coder,
-        List<ConsumerAndMetadata> consumerAndMetadatas,
-        MetricsEnvironmentState metricsEnvironmentState) {
+        String pCollectionId, Coder<T> coder, List<ConsumerAndMetadata> 
consumerAndMetadatas) {
       this.consumerAndMetadatas = consumerAndMetadatas;
-      this.metricsEnvironmentState = metricsEnvironmentState;
 
       HashMap<String, String> labels = new HashMap<>();
       labels.put(Labels.PCOLLECTION, pCollectionId);
@@ -384,20 +349,16 @@ public class PCollectionConsumerRegistry {
       // when we have window optimization.
       this.sampledByteSizeDistribution.tryUpdate(input.getValue(), coder);
 
-      // Wrap the consumer with extra logic to set the metric container with 
the appropriate
-      // PTransform context. This ensures that user metrics obtain the 
pTransform ID when they are
-      // created. Also use the ExecutionStateTracker and enter an appropriate 
state to track the
-      // Process Bundle Execution time metric.
+      // Use the ExecutionStateTracker and enter an appropriate state to track 
the
+      // Process Bundle Execution time metric and also ensure user counters 
can get an appropriate
+      // metrics container.
       for (ConsumerAndMetadata consumerAndMetadata : consumerAndMetadatas) {
-        MetricsContainer oldContainer =
-            
metricsEnvironmentState.activate(consumerAndMetadata.getMetricsContainer());
         ExecutionState state = consumerAndMetadata.getExecutionState();
         state.activate();
         try {
           consumerAndMetadata.getConsumer().accept(input);
         } finally {
           state.deactivate();
-          metricsEnvironmentState.activate(oldContainer);
         }
         this.sampledByteSizeDistribution.finishLazyUpdate();
       }
@@ -416,11 +377,8 @@ public class PCollectionConsumerRegistry {
     private final HandlesSplits delegate;
 
     public SplittingMetricTrackingFnDataReceiver(
-        String pCollection,
-        Coder<T> coder,
-        ConsumerAndMetadata consumerAndMetadata,
-        MetricsEnvironmentState metricsEnvironmentState) {
-      super(pCollection, coder, consumerAndMetadata, metricsEnvironmentState);
+        String pCollection, Coder<T> coder, ConsumerAndMetadata 
consumerAndMetadata) {
+      super(pCollection, coder, consumerAndMetadata);
       this.delegate = (HandlesSplits) consumerAndMetadata.getConsumer();
     }
 
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java
index f6bd008c424..ea0a9e76a28 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java
@@ -22,23 +22,21 @@ import java.util.List;
 import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionState;
 import 
org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker;
 import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
-import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
 import org.apache.beam.runners.core.metrics.ShortIdMap;
 import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
 import org.apache.beam.sdk.function.ThrowingRunnable;
-import org.apache.beam.sdk.metrics.MetricsContainer;
-import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState;
 
 /**
- * A class to to register and retrieve functions for bundle processing (i.e. 
the start, or finish
+ * A class to register and retrieve functions for bundle processing (i.e. the 
start, or finish
  * function). The purpose of this class is to wrap these functions with 
instrumentation for metrics
  * and other telemetry collection.
  *
- * <p>Usage: // Instantiate and use the registry for each class of functions. 
i.e. start. finish.
+ * <p>Usage:
  *
  * <pre>
+ * // Instantiate and use the registry for each class of functions. i.e. 
start. finish.
  * PTransformFunctionRegistry startFunctionRegistry;
  * PTransformFunctionRegistry finishFunctionRegistry;
  * startFunctionRegistry.register(myStartThrowingRunnable);
@@ -57,8 +55,6 @@ import 
org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState;
  */
 public class PTransformFunctionRegistry {
 
-  private final MetricsContainerStepMap metricsContainerRegistry;
-  private final MetricsEnvironmentState metricsEnvironmentState;
   private final ExecutionStateTracker stateTracker;
   private final String executionStateUrn;
   private final ShortIdMap shortIds;
@@ -68,20 +64,12 @@ public class PTransformFunctionRegistry {
   /**
    * Construct the registry to run for either start or finish bundle functions.
    *
-   * @param metricsContainerRegistry - Used to enable a metric container to 
properly account for the
-   *     pTransform in user metrics.
-   * @param metricsEnvironmentState - Used to activate which metrics container 
receives counter
-   *     updates.
    * @param shortIds - Provides short ids for {@link MonitoringInfo}.
    * @param stateTracker - The tracker to enter states in order to calculate 
execution time metrics.
    * @param executionStateUrn - The URN for the execution state.
    */
   public PTransformFunctionRegistry(
-      MetricsContainerStepMap metricsContainerRegistry,
-      MetricsEnvironmentState metricsEnvironmentState,
-      ShortIdMap shortIds,
-      ExecutionStateTracker stateTracker,
-      String executionStateUrn) {
+      ShortIdMap shortIds, ExecutionStateTracker stateTracker, String 
executionStateUrn) {
     switch (executionStateUrn) {
       case Urns.START_BUNDLE_MSECS:
         stateName = 
org.apache.beam.runners.core.metrics.ExecutionStateTracker.START_STATE_NAME;
@@ -92,8 +80,6 @@ public class PTransformFunctionRegistry {
       default:
         throw new IllegalArgumentException(String.format("Unknown URN %s", 
executionStateUrn));
     }
-    this.metricsContainerRegistry = metricsContainerRegistry;
-    this.metricsEnvironmentState = metricsEnvironmentState;
     this.shortIds = shortIds;
     this.executionStateUrn = executionStateUrn;
     this.stateTracker = stateTracker;
@@ -123,17 +109,13 @@ public class PTransformFunctionRegistry {
     ExecutionState executionState =
         stateTracker.create(shortId, pTransformId, pTransformUniqueName, 
stateName);
 
-    MetricsContainer container = 
metricsContainerRegistry.getContainer(pTransformId);
-
     ThrowingRunnable wrapped =
         () -> {
-          MetricsContainer oldContainer = 
metricsEnvironmentState.activate(container);
           executionState.activate();
           try {
             runnable.run();
           } finally {
             executionState.deactivate();
-            metricsEnvironmentState.activate(oldContainer);
           }
         };
     runnables.add(wrapped);
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java
index 72a106a75a7..01cdc474e4e 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java
@@ -33,11 +33,21 @@ import 
org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionState;
 import 
org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker;
 import 
org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTrackerStatus;
 import org.apache.beam.runners.core.metrics.MonitoringInfoEncodings;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.DelegatingHistogram;
+import org.apache.beam.sdk.metrics.Distribution;
+import org.apache.beam.sdk.metrics.Gauge;
+import org.apache.beam.sdk.metrics.Histogram;
+import org.apache.beam.sdk.metrics.MetricName;
+import org.apache.beam.sdk.metrics.Metrics;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.testing.ExpectedLogs;
+import org.apache.beam.sdk.util.HistogramData;
 import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString;
 import org.joda.time.DateTimeUtils.MillisProvider;
 import org.joda.time.Duration;
+import org.junit.After;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -50,8 +60,21 @@ import org.mockito.stubbing.Answer;
 @RunWith(JUnit4.class)
 public class ExecutionStateSamplerTest {
 
+  private static final Counter TEST_USER_COUNTER = Metrics.counter("foo", 
"counter");
+  private static final Distribution TEST_USER_DISTRIBUTION =
+      Metrics.distribution("foo", "distribution");
+  private static final Gauge TEST_USER_GAUGE = Metrics.gauge("foo", "gauge");
+  private static final Histogram TEST_USER_HISTOGRAM =
+      new DelegatingHistogram(
+          MetricName.named("foo", "histogram"), 
HistogramData.LinearBuckets.of(0, 100, 1), false);
+
   @Rule public ExpectedLogs expectedLogs = 
ExpectedLogs.none(ExecutionStateSampler.class);
 
+  @After
+  public void tearDown() {
+    MetricsEnvironment.setCurrentContainer(null);
+  }
+
   @Test
   public void testSamplingProducesCorrectFinalResults() throws Exception {
     MillisProvider clock = mock(MillisProvider.class);
@@ -335,6 +358,108 @@ public class ExecutionStateSamplerTest {
     expectedLogs.verifyNotLogged("Operation ongoing");
   }
 
+  @Test
+  public void testCountersReturnedAreBasedUponCurrentExecutionState() throws 
Exception {
+    MillisProvider clock = mock(MillisProvider.class);
+    ExecutionStateSampler sampler =
+        new ExecutionStateSampler(
+            
PipelineOptionsFactory.fromArgs("--experiments=state_sampling_period_millis=10")
+                .create(),
+            clock);
+    ExecutionStateTracker tracker = sampler.create();
+    MetricsEnvironment.setCurrentContainer(tracker.getMetricsContainer());
+    ExecutionState state = tracker.create("shortId", "ptransformId", 
"uniqueName", "state");
+
+    state.activate();
+    TEST_USER_COUNTER.inc();
+    TEST_USER_DISTRIBUTION.update(2);
+    TEST_USER_GAUGE.set(3);
+    TEST_USER_HISTOGRAM.update(4);
+    state.deactivate();
+
+    TEST_USER_COUNTER.inc(11);
+    TEST_USER_DISTRIBUTION.update(12);
+    TEST_USER_GAUGE.set(13);
+    TEST_USER_HISTOGRAM.update(14);
+    TEST_USER_HISTOGRAM.update(14);
+
+    // Verify the execution state was updated
+    assertEquals(
+        1L,
+        (long)
+            tracker
+                .getMetricsContainerRegistry()
+                .getContainer("ptransformId")
+                .getCounter(TEST_USER_COUNTER.getName())
+                .getCumulative());
+    assertEquals(
+        2L,
+        (long)
+            tracker
+                .getMetricsContainerRegistry()
+                .getContainer("ptransformId")
+                .getDistribution(TEST_USER_DISTRIBUTION.getName())
+                .getCumulative()
+                .sum());
+    assertEquals(
+        3L,
+        (long)
+            tracker
+                .getMetricsContainerRegistry()
+                .getContainer("ptransformId")
+                .getGauge(TEST_USER_GAUGE.getName())
+                .getCumulative()
+                .value());
+    assertEquals(
+        1L,
+        (long)
+            tracker
+                .getMetricsContainerRegistry()
+                .getContainer("ptransformId")
+                .getHistogram(
+                    TEST_USER_HISTOGRAM.getName(), 
HistogramData.LinearBuckets.of(0, 100, 1))
+                .getCumulative()
+                .getCount(0));
+
+    // Verify the unbound container
+    assertEquals(
+        11L,
+        (long)
+            tracker
+                .getMetricsContainerRegistry()
+                .getUnboundContainer()
+                .getCounter(TEST_USER_COUNTER.getName())
+                .getCumulative());
+    assertEquals(
+        12L,
+        (long)
+            tracker
+                .getMetricsContainerRegistry()
+                .getUnboundContainer()
+                .getDistribution(TEST_USER_DISTRIBUTION.getName())
+                .getCumulative()
+                .sum());
+    assertEquals(
+        13L,
+        (long)
+            tracker
+                .getMetricsContainerRegistry()
+                .getUnboundContainer()
+                .getGauge(TEST_USER_GAUGE.getName())
+                .getCumulative()
+                .value());
+    assertEquals(
+        2L,
+        (long)
+            tracker
+                .getMetricsContainerRegistry()
+                .getUnboundContainer()
+                .getHistogram(
+                    TEST_USER_HISTOGRAM.getName(), 
HistogramData.LinearBuckets.of(0, 100, 1))
+                .getCumulative()
+                .getCount(0));
+  }
+
   @Test
   public void testTrackerReuse() throws Exception {
     MillisProvider clock = mock(MillisProvider.class);
@@ -344,6 +469,7 @@ public class ExecutionStateSamplerTest {
                 .create(),
             clock);
     ExecutionStateTracker tracker = sampler.create();
+    MetricsEnvironment.setCurrentContainer(tracker.getMetricsContainer());
     ExecutionState state = tracker.create("shortId", "ptransformId", 
"ptransformIdName", "process");
 
     CountDownLatch waitTillActive = new CountDownLatch(1);
@@ -384,6 +510,7 @@ public class ExecutionStateSamplerTest {
       state.activate();
       waitTillActive.countDown();
       waitForSamples.await();
+      TEST_USER_COUNTER.inc();
       state.deactivate();
       Map<String, ByteString> finalResults = new HashMap<>();
       tracker.updateFinalMonitoringData(finalResults);
@@ -393,6 +520,14 @@ public class ExecutionStateSamplerTest {
           // The CountDownLatch ensures that we will see either the prior 
value or
           // the latest value.
           anyOf(equalTo(900L), equalTo(1000L)));
+      assertEquals(
+          1L,
+          (long)
+              tracker
+                  .getMetricsContainerRegistry()
+                  .getContainer("ptransformId")
+                  .getCounter(TEST_USER_COUNTER.getName())
+                  .getCumulative());
       tracker.reset();
     }
 
@@ -401,6 +536,7 @@ public class ExecutionStateSamplerTest {
       state.activate();
       waitTillSecondStateActive.countDown();
       waitForMoreSamples.await();
+      TEST_USER_COUNTER.inc();
       state.deactivate();
       Map<String, ByteString> finalResults = new HashMap<>();
       tracker.updateFinalMonitoringData(finalResults);
@@ -410,6 +546,14 @@ public class ExecutionStateSamplerTest {
           // The CountDownLatch ensures that we will see either the prior 
value or
           // the latest value.
           anyOf(equalTo(400L), equalTo(500L)));
+      assertEquals(
+          1L,
+          (long)
+              tracker
+                  .getMetricsContainerRegistry()
+                  .getContainer("ptransformId")
+                  .getCounter(TEST_USER_COUNTER.getName())
+                  .getCumulative());
       tracker.reset();
     }
 
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
index 404e63b3edf..7df9ed2f894 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java
@@ -115,7 +115,6 @@ import 
org.apache.beam.runners.core.construction.ModelCoders;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.ParDoTranslation;
 import org.apache.beam.runners.core.construction.Timer;
-import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.ShortIdMap;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
@@ -126,6 +125,7 @@ import org.apache.beam.sdk.fn.data.TimerEndpoint;
 import org.apache.beam.sdk.fn.test.TestExecutors;
 import org.apache.beam.sdk.fn.test.TestExecutors.TestExecutorService;
 import org.apache.beam.sdk.function.ThrowingRunnable;
+import org.apache.beam.sdk.metrics.MetricsEnvironment;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.state.TimeDomain;
 import org.apache.beam.sdk.state.TimerMap;
@@ -290,11 +290,6 @@ public class ProcessBundleHandlerTest {
       return wrappedBundleProcessor.getpCollectionConsumerRegistry();
     }
 
-    @Override
-    MetricsContainerStepMap getMetricsContainerRegistry() {
-      return wrappedBundleProcessor.getMetricsContainerRegistry();
-    }
-
     @Override
     MetricsEnvironmentStateForBundle getMetricsEnvironmentStateForBundle() {
       return wrappedBundleProcessor.getMetricsEnvironmentStateForBundle();
@@ -730,7 +725,6 @@ public class ProcessBundleHandlerTest {
     Collection<CallbackRegistration> bundleFinalizationCallbacks = 
mock(Collection.class);
     PCollectionConsumerRegistry pCollectionConsumerRegistry =
         mock(PCollectionConsumerRegistry.class);
-    MetricsContainerStepMap metricsContainerRegistry = 
mock(MetricsContainerStepMap.class);
     ExecutionStateTracker stateTracker = mock(ExecutionStateTracker.class);
     ProcessBundleHandler.HandleStateCallsForBundle beamFnStateClient =
         mock(ProcessBundleHandler.HandleStateCallsForBundle.class);
@@ -747,7 +741,6 @@ public class ProcessBundleHandlerTest {
             new ArrayList<>(),
             splitListener,
             pCollectionConsumerRegistry,
-            metricsContainerRegistry,
             new MetricsEnvironmentStateForBundle(),
             stateTracker,
             beamFnStateClient,
@@ -771,10 +764,10 @@ public class ProcessBundleHandlerTest {
     assertNull(bundleProcessor.getCacheTokens());
     assertNull(bundleCache.peek("A"));
     verify(splitListener, times(1)).clear();
-    verify(metricsContainerRegistry, times(1)).reset();
     verify(stateTracker, times(1)).reset();
     verify(bundleFinalizationCallbacks, times(1)).clear();
     verify(resetFunction, times(1)).run();
+    assertNull(MetricsEnvironment.getCurrentContainer());
 
     // Ensure that the next setup produces the expected state.
     bundleProcessor.setupForProcessBundleRequest(
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
index f65237c986e..35bd5697adc 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java
@@ -20,6 +20,7 @@ package org.apache.beam.fn.harness.data;
 import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
@@ -27,7 +28,6 @@ import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -38,12 +38,12 @@ import java.util.Map;
 import org.apache.beam.fn.harness.HandlesSplits;
 import org.apache.beam.fn.harness.control.BundleProgressReporter;
 import org.apache.beam.fn.harness.control.ExecutionStateSampler;
+import 
org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor;
 import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.runners.core.construction.SdkComponents;
 import org.apache.beam.runners.core.metrics.DistributionData;
-import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Labels;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
@@ -52,9 +52,9 @@ import 
org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
-import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
-import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable;
@@ -68,8 +68,6 @@ import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
-import org.mockito.InOrder;
-import org.mockito.Mockito;
 import org.mockito.stubbing.Answer;
 
 /** Tests for {@link PCollectionConsumerRegistry}. */
@@ -78,6 +76,7 @@ import org.mockito.stubbing.Answer;
   "rawtypes", // TODO(https://github.com/apache/beam/issues/20447)
 })
 public class PCollectionConsumerRegistryTest {
+  private static final Counter TEST_USER_COUNTER = Metrics.counter("foo", 
"bar");
 
   @Rule public ExpectedException expectedException = ExpectedException.none();
 
@@ -114,6 +113,7 @@ public class PCollectionConsumerRegistryTest {
 
   @After
   public void tearDown() throws Exception {
+    MetricsEnvironment.setCurrentContainer(null);
     sampler.stop();
   }
 
@@ -121,17 +121,11 @@ public class PCollectionConsumerRegistryTest {
   public void singleConsumer() throws Exception {
     final String pTransformIdA = "pTransformIdA";
 
-    MetricsContainerStepMap metricsContainerRegistry = new 
MetricsContainerStepMap();
     ShortIdMap shortIds = new ShortIdMap();
     BundleProgressReporter.InMemory reporterAndRegistrar = new 
BundleProgressReporter.InMemory();
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
-            metricsContainerRegistry,
-            MetricsEnvironment::setCurrentContainer,
-            sampler.create(),
-            shortIds,
-            reporterAndRegistrar,
-            TEST_DESCRIPTOR);
+            sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<String>> consumerA1 = 
mock(FnDataReceiver.class);
 
     consumers.register(P_COLLECTION_A, pTransformIdA, pTransformIdA + "Name", 
consumerA1);
@@ -183,17 +177,11 @@ public class PCollectionConsumerRegistryTest {
     final String pTransformId = "pTransformId";
     final String message = "testException";
 
-    MetricsContainerStepMap metricsContainerRegistry = new 
MetricsContainerStepMap();
     ShortIdMap shortIds = new ShortIdMap();
     BundleProgressReporter.InMemory reporterAndRegistrar = new 
BundleProgressReporter.InMemory();
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
-            metricsContainerRegistry,
-            MetricsEnvironment::setCurrentContainer,
-            sampler.create(),
-            shortIds,
-            reporterAndRegistrar,
-            TEST_DESCRIPTOR);
+            sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<String>> consumer = 
mock(FnDataReceiver.class);
 
     consumers.register(P_COLLECTION_A, pTransformId, pTransformId + "Name", 
consumer);
@@ -211,17 +199,11 @@ public class PCollectionConsumerRegistryTest {
   /** Test that the counter increments even when there are no consumers of the 
PCollection. */
   @Test
   public void noConsumers() throws Exception {
-    MetricsContainerStepMap metricsContainerRegistry = new 
MetricsContainerStepMap();
     ShortIdMap shortIds = new ShortIdMap();
     BundleProgressReporter.InMemory reporterAndRegistrar = new 
BundleProgressReporter.InMemory();
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
-            metricsContainerRegistry,
-            MetricsEnvironment::setCurrentContainer,
-            sampler.create(),
-            shortIds,
-            reporterAndRegistrar,
-            TEST_DESCRIPTOR);
+            sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR);
 
     FnDataReceiver<WindowedValue<String>> wrapperConsumer =
         (FnDataReceiver<WindowedValue<String>>)
@@ -271,17 +253,11 @@ public class PCollectionConsumerRegistryTest {
     final String pTransformIdA = "pTransformIdA";
     final String pTransformIdB = "pTransformIdB";
 
-    MetricsContainerStepMap metricsContainerRegistry = new 
MetricsContainerStepMap();
     ShortIdMap shortIds = new ShortIdMap();
     BundleProgressReporter.InMemory reporterAndRegistrar = new 
BundleProgressReporter.InMemory();
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
-            metricsContainerRegistry,
-            MetricsEnvironment::setCurrentContainer,
-            sampler.create(),
-            shortIds,
-            reporterAndRegistrar,
-            TEST_DESCRIPTOR);
+            sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<String>> consumerA1 = 
mock(FnDataReceiver.class);
     FnDataReceiver<WindowedValue<String>> consumerA2 = 
mock(FnDataReceiver.class);
 
@@ -336,17 +312,11 @@ public class PCollectionConsumerRegistryTest {
     final String pTransformId = "pTransformId";
     final String message = "testException";
 
-    MetricsContainerStepMap metricsContainerRegistry = new 
MetricsContainerStepMap();
     ShortIdMap shortIds = new ShortIdMap();
     BundleProgressReporter.InMemory reporterAndRegistrar = new 
BundleProgressReporter.InMemory();
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
-            metricsContainerRegistry,
-            MetricsEnvironment::setCurrentContainer,
-            sampler.create(),
-            shortIds,
-            reporterAndRegistrar,
-            TEST_DESCRIPTOR);
+            sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<String>> consumerA1 = 
mock(FnDataReceiver.class);
     FnDataReceiver<WindowedValue<String>> consumerA2 = 
mock(FnDataReceiver.class);
 
@@ -367,17 +337,11 @@ public class PCollectionConsumerRegistryTest {
   public void throwsOnRegisteringAfterMultiplexingConsumerWasInitialized() 
throws Exception {
     final String pTransformId = "pTransformId";
 
-    MetricsContainerStepMap metricsContainerRegistry = new 
MetricsContainerStepMap();
     ShortIdMap shortIds = new ShortIdMap();
     BundleProgressReporter.InMemory reporterAndRegistrar = new 
BundleProgressReporter.InMemory();
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
-            metricsContainerRegistry,
-            MetricsEnvironment::setCurrentContainer,
-            sampler.create(),
-            shortIds,
-            reporterAndRegistrar,
-            TEST_DESCRIPTOR);
+            sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<String>> consumerA1 = 
mock(FnDataReceiver.class);
     FnDataReceiver<WindowedValue<String>> consumerA2 = 
mock(FnDataReceiver.class);
 
@@ -391,31 +355,19 @@ public class PCollectionConsumerRegistryTest {
 
   @Test
   public void testMetricContainerUpdatedUponAcceptingElement() throws 
Exception {
-    MetricsEnvironmentState metricsEnvironmentState = 
mock(MetricsEnvironmentState.class);
-
-    MetricsContainerStepMap metricsContainerRegistry = new 
MetricsContainerStepMap();
+    ExecutionStateTracker executionStateTracker = sampler.create();
+    
MetricsEnvironment.setCurrentContainer(executionStateTracker.getMetricsContainer());
     ShortIdMap shortIds = new ShortIdMap();
     BundleProgressReporter.InMemory reporterAndRegistrar = new 
BundleProgressReporter.InMemory();
+    executionStateTracker.start("testBundle");
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
-            metricsContainerRegistry,
-            metricsEnvironmentState,
-            sampler.create(),
-            shortIds,
-            reporterAndRegistrar,
-            TEST_DESCRIPTOR);
-    FnDataReceiver<WindowedValue<String>> consumerA1 = 
mock(FnDataReceiver.class);
-    FnDataReceiver<WindowedValue<String>> consumerA2 = 
mock(FnDataReceiver.class);
+            executionStateTracker, shortIds, reporterAndRegistrar, 
TEST_DESCRIPTOR);
 
-    consumers.register(P_COLLECTION_A, "pTransformA", "pTransformAName", 
consumerA1);
-    consumers.register(P_COLLECTION_A, "pTransformB", "pTransformBName", 
consumerA2);
-
-    // Test both cases; when there is an existing container and where there is 
no container
-    MetricsContainer oldContainer = mock(MetricsContainer.class);
-    
when(metricsEnvironmentState.activate(metricsContainerRegistry.getContainer("pTransformA")))
-        .thenReturn(oldContainer);
-    
when(metricsEnvironmentState.activate(metricsContainerRegistry.getContainer("pTransformB")))
-        .thenReturn(null);
+    consumers.register(
+        P_COLLECTION_A, "pTransformA", "pTransformAName", (unused) -> 
TEST_USER_COUNTER.inc());
+    consumers.register(
+        P_COLLECTION_A, "pTransformB", "pTransformBName", (unused) -> 
TEST_USER_COUNTER.inc(2));
 
     FnDataReceiver<WindowedValue<String>> wrapperConsumer =
         (FnDataReceiver<WindowedValue<String>>)
@@ -423,36 +375,45 @@ public class PCollectionConsumerRegistryTest {
 
     WindowedValue<String> element = valueInGlobalWindow("elem");
     wrapperConsumer.accept(element);
-
-    // Verify that metrics environment state is updated with pTransformA's 
container, then reset to
-    // the oldContainer, then pTransformB's container and then reset to null.
-    InOrder inOrder = Mockito.inOrder(metricsEnvironmentState);
-    inOrder
-        .verify(metricsEnvironmentState)
-        .activate(metricsContainerRegistry.getContainer("pTransformA"));
-    inOrder.verify(metricsEnvironmentState).activate(oldContainer);
-    inOrder
-        .verify(metricsEnvironmentState)
-        .activate(metricsContainerRegistry.getContainer("pTransformB"));
-    inOrder.verify(metricsEnvironmentState).activate(null);
-    inOrder.verifyNoMoreInteractions();
+    TEST_USER_COUNTER.inc(3);
+
+    // Verify that metrics environment state is updated with pTransform's 
counters including the
+    // unbound container when outside the scope of the function
+    assertEquals(
+        1L,
+        (long)
+            executionStateTracker
+                .getMetricsContainerRegistry()
+                .getContainer("pTransformA")
+                .getCounter(TEST_USER_COUNTER.getName())
+                .getCumulative());
+    assertEquals(
+        2L,
+        (long)
+            executionStateTracker
+                .getMetricsContainerRegistry()
+                .getContainer("pTransformB")
+                .getCounter(TEST_USER_COUNTER.getName())
+                .getCumulative());
+    assertEquals(
+        3L,
+        (long)
+            executionStateTracker
+                .getMetricsContainerRegistry()
+                .getUnboundContainer()
+                .getCounter(TEST_USER_COUNTER.getName())
+                .getCumulative());
   }
 
   @Test
   public void testHandlesSplitsPassedToOriginalConsumer() throws Exception {
     final String pTransformIdA = "pTransformIdA";
 
-    MetricsContainerStepMap metricsContainerRegistry = new 
MetricsContainerStepMap();
     ShortIdMap shortIds = new ShortIdMap();
     BundleProgressReporter.InMemory reporterAndRegistrar = new 
BundleProgressReporter.InMemory();
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
-            metricsContainerRegistry,
-            MetricsEnvironment::setCurrentContainer,
-            sampler.create(),
-            shortIds,
-            reporterAndRegistrar,
-            TEST_DESCRIPTOR);
+            sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR);
     SplittingReceiver consumerA1 = mock(SplittingReceiver.class);
 
     consumers.register(P_COLLECTION_A, pTransformIdA, pTransformIdA + "Name", 
consumerA1);
@@ -474,17 +435,11 @@ public class PCollectionConsumerRegistryTest {
   public void testLazyByteSizeEstimation() throws Exception {
     final String pTransformIdA = "pTransformIdA";
 
-    MetricsContainerStepMap metricsContainerRegistry = new 
MetricsContainerStepMap();
     ShortIdMap shortIds = new ShortIdMap();
     BundleProgressReporter.InMemory reporterAndRegistrar = new 
BundleProgressReporter.InMemory();
     PCollectionConsumerRegistry consumers =
         new PCollectionConsumerRegistry(
-            metricsContainerRegistry,
-            MetricsEnvironment::setCurrentContainer,
-            sampler.create(),
-            shortIds,
-            reporterAndRegistrar,
-            TEST_DESCRIPTOR);
+            sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR);
     FnDataReceiver<WindowedValue<Iterable<String>>> consumerA1 = 
mock(FnDataReceiver.class);
 
     consumers.register(P_COLLECTION_B, pTransformIdA, pTransformIdA + "Name", 
consumerA1);
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java
index 7def4286258..6e06c16c653 100644
--- 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java
@@ -20,32 +20,28 @@ package org.apache.beam.fn.harness.data;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
 
 import java.util.concurrent.atomic.AtomicBoolean;
 import org.apache.beam.fn.harness.control.ExecutionStateSampler;
 import 
org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker;
 import 
org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTrackerStatus;
-import org.apache.beam.runners.core.metrics.MetricsContainerStepMap;
 import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns;
 import org.apache.beam.runners.core.metrics.ShortIdMap;
 import org.apache.beam.sdk.function.ThrowingRunnable;
-import org.apache.beam.sdk.metrics.MetricsContainer;
+import org.apache.beam.sdk.metrics.Counter;
+import org.apache.beam.sdk.metrics.Metrics;
 import org.apache.beam.sdk.metrics.MetricsEnvironment;
-import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
-import org.mockito.InOrder;
-import org.mockito.Mockito;
 
 /** Tests for {@link PTransformFunctionRegistry}. */
 @RunWith(JUnit4.class)
 public class PTransformFunctionRegistryTest {
+  private static final Counter TEST_USER_COUNTER = Metrics.counter("foo", 
"bar");
 
   private ExecutionStateSampler sampler;
 
@@ -56,19 +52,17 @@ public class PTransformFunctionRegistryTest {
 
   @After
   public void tearDown() {
+    MetricsEnvironment.setCurrentContainer(null);
     sampler.stop();
   }
 
   @Test
   public void testStateTrackerRecordsStateTransitions() throws Exception {
     ExecutionStateTracker executionStateTracker = sampler.create();
+    
MetricsEnvironment.setCurrentContainer(executionStateTracker.getMetricsContainer());
     PTransformFunctionRegistry testObject =
         new PTransformFunctionRegistry(
-            mock(MetricsContainerStepMap.class),
-            MetricsEnvironment::setCurrentContainer,
-            new ShortIdMap(),
-            executionStateTracker,
-            Urns.START_BUNDLE_MSECS);
+            new ShortIdMap(), executionStateTracker, Urns.START_BUNDLE_MSECS);
 
     final AtomicBoolean runnableAWasCalled = new AtomicBoolean();
     final AtomicBoolean runnableBWasCalled = new AtomicBoolean();
@@ -111,46 +105,49 @@ public class PTransformFunctionRegistryTest {
 
   @Test
   public void testMetricsUponRunningFunctions() throws Exception {
-    MetricsEnvironmentState metricsEnvironmentState = 
mock(MetricsEnvironmentState.class);
     ExecutionStateTracker executionStateTracker = sampler.create();
-    MetricsContainerStepMap metricsContainerRegistry = new 
MetricsContainerStepMap();
+    
MetricsEnvironment.setCurrentContainer(executionStateTracker.getMetricsContainer());
     PTransformFunctionRegistry testObject =
         new PTransformFunctionRegistry(
-            metricsContainerRegistry,
-            metricsEnvironmentState,
-            new ShortIdMap(),
-            executionStateTracker,
-            Urns.START_BUNDLE_MSECS);
+            new ShortIdMap(), executionStateTracker, Urns.START_BUNDLE_MSECS);
 
-    ThrowingRunnable runnableA = mock(ThrowingRunnable.class);
-    ThrowingRunnable runnableB = mock(ThrowingRunnable.class);
-    testObject.register("pTransformA", "pTranformAName", runnableA);
-    testObject.register("pTransformB", "pTranformBName", runnableB);
+    testObject.register("pTransformA", "pTranformAName", () -> 
TEST_USER_COUNTER.inc());
+    testObject.register("pTransformB", "pTranformBName", () -> 
TEST_USER_COUNTER.inc(2));
 
     // Test both cases; when there is an existing container and where there is 
no container
-    MetricsContainer oldContainer = mock(MetricsContainer.class);
-    
when(metricsEnvironmentState.activate(metricsContainerRegistry.getContainer("pTransformA")))
-        .thenReturn(oldContainer);
-    
when(metricsEnvironmentState.activate(metricsContainerRegistry.getContainer("pTransformB")))
-        .thenReturn(null);
-
     executionStateTracker.start("testBundleId");
     for (ThrowingRunnable func : testObject.getFunctions()) {
       func.run();
     }
-    executionStateTracker.reset();
+    TEST_USER_COUNTER.inc(3);
+
+    // Verify that each pTransform's metrics container recorded the increment made inside its
+    // function, and that the increment made outside any function went to the unbound container.
+    assertEquals(
+        1L,
+        (long)
+            executionStateTracker
+                .getMetricsContainerRegistry()
+                .getContainer("pTransformA")
+                .getCounter(TEST_USER_COUNTER.getName())
+                .getCumulative());
+    assertEquals(
+        2L,
+        (long)
+            executionStateTracker
+                .getMetricsContainerRegistry()
+                .getContainer("pTransformB")
+                .getCounter(TEST_USER_COUNTER.getName())
+                .getCumulative());
+    assertEquals(
+        3L,
+        (long)
+            executionStateTracker
+                .getMetricsContainerRegistry()
+                .getUnboundContainer()
+                .getCounter(TEST_USER_COUNTER.getName())
+                .getCumulative());
 
-    // Verify that metrics environment state is updated with pTransformA's 
container, then reset to
-    // the oldContainer, then pTransformB's container and then reset to null.
-    InOrder inOrder = Mockito.inOrder(metricsEnvironmentState);
-    inOrder
-        .verify(metricsEnvironmentState)
-        .activate(metricsContainerRegistry.getContainer("pTransformA"));
-    inOrder.verify(metricsEnvironmentState).activate(oldContainer);
-    inOrder
-        .verify(metricsEnvironmentState)
-        .activate(metricsContainerRegistry.getContainer("pTransformB"));
-    inOrder.verify(metricsEnvironmentState).activate(null);
-    inOrder.verifyNoMoreInteractions();
+    executionStateTracker.reset();
   }
 }

Reply via email to