This is an automated email from the ASF dual-hosted git repository. lcwik pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 679d30256c6 Swap setting a context from being on the hot path when we emit elements to only be done during bundle creation and teardown #21250 (#25291) 679d30256c6 is described below commit 679d30256c6bd64d9760702c667d7d355e70166b Author: Luke Cwik <lc...@google.com> AuthorDate: Fri Feb 3 15:39:54 2023 -0800 Swap setting a context from being on the hot path when we emit elements to only be done during bundle creation and teardown #21250 (#25291) * Swap setting a context from being on the hot path when we emit elements to only be done during bundle creation and teardown. This cuts our per element processing overhead significantly down. For example in the ProcessBundleBenchmark.testLargeBundle we saw a reduction of about half of the execution state management overhead for each element (10.6% to 5.6% CPU overhead per element). The benchmark reflects this with a nice 11.5% performance improvement. Before: ``` Result "org.apache.beam.fn.harness.jmh.ProcessBundleBenchmark.testLargeBundle": 3265.141 ±(99.9%) 238.744 ops/s [Average] (min, avg, max) = (2952.616, 3265.141, 3525.482), stdev = 223.321 CI (99.9%): [3026.398, 3503.885] (assumes normal distribution) ``` After: ``` Result "org.apache.beam.fn.harness.jmh.ProcessBundleBenchmark.testLargeBundle": 3642.512 ±(99.9%) 45.865 ops/s [Average] (min, avg, max) = (3582.394, 3642.512, 3713.820), stdev = 42.902 CI (99.9%): [3596.647, 3688.377] (assumes normal distribution) ``` --- .../fn/harness/control/ExecutionStateSampler.java | 109 +++++++++++++++- .../fn/harness/control/ProcessBundleHandler.java | 57 +++----- .../harness/data/PCollectionConsumerRegistry.java | 74 +++-------- .../harness/data/PTransformFunctionRegistry.java | 26 +--- .../harness/control/ExecutionStateSamplerTest.java | 144 ++++++++++++++++++++ .../harness/control/ProcessBundleHandlerTest.java | 11 +- .../data/PCollectionConsumerRegistryTest.java | 145 +++++++-------------- .../data/PTransformFunctionRegistryTest.java | 81 ++++++------ 8 files changed, 382 insertions(+), 265 deletions(-) diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java index f528a4a9919..9a94ce05057 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ExecutionStateSampler.java @@ -32,10 +32,19 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import javax.annotation.concurrent.GuardedBy; import org.apache.beam.fn.harness.control.ProcessBundleHandler.BundleProcessor; +import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo; +import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; import org.apache.beam.runners.core.metrics.MonitoringInfoEncodings; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Distribution; +import org.apache.beam.sdk.metrics.Gauge; +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.MetricsContainer; import org.apache.beam.sdk.options.ExecutorOptions; import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.util.HistogramData; import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner; import org.checkerframework.checker.nullness.qual.Nullable; @@ -163,9 +172,66 @@ public class ExecutionStateSampler { return new ExecutionStateTracker(); } + /** + * A {@link MetricsContainer} that uses the current {@link ExecutionState} tracked by the provided + * {@link ExecutionStateTracker}. + */ + private static class MetricsContainerForTracker implements MetricsContainer { + private final transient ExecutionStateTracker tracker; + + private MetricsContainerForTracker(ExecutionStateTracker tracker) { + this.tracker = tracker; + } + + @Override + public Counter getCounter(MetricName metricName) { + if (tracker.currentState != null) { + return tracker.currentState.metricsContainer.getCounter(metricName); + } + return tracker.metricsContainerRegistry.getUnboundContainer().getCounter(metricName); + } + + @Override + public Distribution getDistribution(MetricName metricName) { + if (tracker.currentState != null) { + return tracker.currentState.metricsContainer.getDistribution(metricName); + } + return tracker.metricsContainerRegistry.getUnboundContainer().getDistribution(metricName); + } + + @Override + public Gauge getGauge(MetricName metricName) { + if (tracker.currentState != null) { + return tracker.currentState.metricsContainer.getGauge(metricName); + } + return tracker.metricsContainerRegistry.getUnboundContainer().getGauge(metricName); + } + + @Override + public Histogram getHistogram(MetricName metricName, HistogramData.BucketType bucketType) { + if (tracker.currentState != null) { + return tracker.currentState.metricsContainer.getHistogram(metricName, bucketType); + } + return tracker + .metricsContainerRegistry + .getUnboundContainer() + .getHistogram(metricName, bucketType); + } + + @Override + public Iterable<MonitoringInfo> getMonitoringInfos() { + if (tracker.currentState != null) { + return tracker.currentState.metricsContainer.getMonitoringInfos(); + } + return tracker.metricsContainerRegistry.getUnboundContainer().getMonitoringInfos(); + } + } + /** Tracks the current state of a single execution thread. */ public class ExecutionStateTracker implements BundleProgressReporter { - + // Used to create and store metrics containers for the execution states. + private final MetricsContainerStepMap metricsContainerRegistry; + private final MetricsContainer metricsContainer; // The set of execution states that this tracker is responsible for. Effectively // final since create() should not be invoked once any bundle starts processing. private final List<ExecutionStateImpl> executionStates; @@ -187,13 +253,36 @@ public class ExecutionStateSampler { // Read and written by the ExecutionStateSampler thread private long transitionsAtLastSample; + // Ignore the @UnderInitialization for ExecutionStateTracker since it will be initialized by the + // time this method returns and no references are leaked to other threads during construction. + @SuppressWarnings({"assignment", "argument"}) private ExecutionStateTracker() { + this.metricsContainerRegistry = new MetricsContainerStepMap(); this.executionStates = new ArrayList<>(); this.trackedThread = new AtomicReference<>(); this.lastTransitionTime = new AtomicLong(); this.numTransitionsLazy = new AtomicLong(); this.currentStateLazy = new AtomicReference<>(); this.processBundleId = new AtomicReference<>(); + this.metricsContainer = new MetricsContainerForTracker(this); + } + + /** + * Returns the {@link MetricsContainerStepMap} that is managed by this {@link + * ExecutionStateTracker}. This metrics container registry stores all the user counters + * associated for the current bundle execution. + */ + public MetricsContainerStepMap getMetricsContainerRegistry() { + return metricsContainerRegistry; + } + + /** + * Returns a {@link MetricsContainer} that delegates based upon the current execution state to + * the appropriate metrics container that is bound to the current {@link ExecutionState} or to + * the unbound {@link MetricsContainer} if no execution state is currently running. + */ + public MetricsContainer getMetricsContainer() { + return metricsContainer; } /** @@ -203,7 +292,12 @@ public class ExecutionStateSampler { public ExecutionState create( String shortId, String ptransformId, String ptransformUniqueName, String stateName) { ExecutionStateImpl newState = - new ExecutionStateImpl(shortId, ptransformId, ptransformUniqueName, stateName); + new ExecutionStateImpl( + shortId, + ptransformId, + ptransformUniqueName, + stateName, + metricsContainerRegistry.getContainer(ptransformId)); executionStates.add(newState); return newState; } @@ -285,10 +379,13 @@ public class ExecutionStateSampler { /** {@link ExecutionState} represents the current state of an execution thread. */ private class ExecutionStateImpl implements ExecutionState { + private final String shortId; private final String ptransformId; private final String ptransformUniqueName; private final String stateName; + private final MetricsContainer metricsContainer; + // Read and written by the bundle processing thread frequently. private long msecs; // Read by the ExecutionStateSampler, written by the bundle processing thread frequently. @@ -301,11 +398,16 @@ public class ExecutionStateSampler { private @Nullable ExecutionStateImpl previousState; private ExecutionStateImpl( - String shortId, String ptransformId, String ptransformName, String stateName) { + String shortId, + String ptransformId, + String ptransformName, + String stateName, + MetricsContainer metricsContainer) { this.shortId = shortId; this.ptransformId = ptransformId; this.ptransformUniqueName = ptransformName; this.stateName = stateName; + this.metricsContainer = metricsContainer; this.lazyMsecs = new AtomicLong(); } @@ -406,6 +508,7 @@ public class ExecutionStateSampler { this.numTransitions = 0; this.numTransitionsLazy.lazySet(0); this.lastTransitionTime.lazySet(0); + this.metricsContainerRegistry.reset(); } } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java index f2b95450391..560369a3907 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/control/ProcessBundleHandler.java @@ -74,7 +74,6 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.WindowingStrategy; import org.apache.beam.runners.core.construction.BeamUrns; import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.Timer; -import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns; import org.apache.beam.runners.core.metrics.ShortIdMap; import org.apache.beam.sdk.fn.data.BeamFnDataInboundObserver; @@ -517,7 +516,6 @@ public class ProcessBundleHandler { ExecutionStateTracker stateTracker = bundleProcessor.getStateTracker(); try (HandleStateCallsForBundle beamFnStateClient = bundleProcessor.getBeamFnStateClient()) { - bundleProcessor.getMetricsEnvironmentStateForBundle().start(); stateTracker.start(request.getInstructionId()); try { // Already in reverse topological order so we don't need to do anything. @@ -688,7 +686,10 @@ public class ProcessBundleHandler { Map<String, ByteString> monitoringData = new HashMap<>(); // Extract MonitoringInfos that come from the metrics container registry. monitoringData.putAll( - bundleProcessor.getMetricsContainerRegistry().getMonitoringData(shortIds)); + bundleProcessor + .getStateTracker() + .getMetricsContainerRegistry() + .getMonitoringData(shortIds)); // Add any additional monitoring infos that the "runners" report explicitly. bundleProcessor .getBundleProgressReporterAndRegistrar() @@ -701,7 +702,10 @@ public class ProcessBundleHandler { HashMap<String, ByteString> monitoringData = new HashMap<>(); // Extract MonitoringInfos that come from the metrics container registry. monitoringData.putAll( - bundleProcessor.getMetricsContainerRegistry().getMonitoringData(shortIds)); + bundleProcessor + .getStateTracker() + .getMetricsContainerRegistry() + .getMonitoringData(shortIds)); // Add any additional monitoring infos that the "runners" report explicitly. bundleProcessor .getBundleProgressReporterAndRegistrar() @@ -736,21 +740,22 @@ public class ProcessBundleHandler { } @VisibleForTesting - static class MetricsEnvironmentStateForBundle implements MetricsEnvironmentState { + static class MetricsEnvironmentStateForBundle { private @Nullable MetricsEnvironmentState currentThreadState; - @Override - public @Nullable MetricsContainer activate(@Nullable MetricsContainer metricsContainer) { - return currentThreadState.activate(metricsContainer); - } - - public void start() { + public void start(MetricsContainer container) { currentThreadState = MetricsEnvironment.getMetricsEnvironmentStateForCurrentThread(); + currentThreadState.activate(container); } public void reset() { + currentThreadState.activate(null); currentThreadState = null; } + + public void discard() { + currentThreadState.activate(null); + } } private BundleProcessor createBundleProcessor( @@ -760,35 +765,19 @@ public class ProcessBundleHandler { SetMultimap<String, String> pCollectionIdsToConsumingPTransforms = HashMultimap.create(); BundleProgressReporter.InMemory bundleProgressReporterAndRegistrar = new BundleProgressReporter.InMemory(); - MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap(); MetricsEnvironmentStateForBundle metricsEnvironmentStateForBundle = new MetricsEnvironmentStateForBundle(); ExecutionStateTracker stateTracker = executionStateSampler.create(); bundleProgressReporterAndRegistrar.register(stateTracker); PCollectionConsumerRegistry pCollectionConsumerRegistry = new PCollectionConsumerRegistry( - metricsContainerRegistry, - metricsEnvironmentStateForBundle, - stateTracker, - shortIds, - bundleProgressReporterAndRegistrar, - bundleDescriptor); + stateTracker, shortIds, bundleProgressReporterAndRegistrar, bundleDescriptor); HashSet<String> processedPTransformIds = new HashSet<>(); PTransformFunctionRegistry startFunctionRegistry = - new PTransformFunctionRegistry( - metricsContainerRegistry, - metricsEnvironmentStateForBundle, - shortIds, - stateTracker, - Urns.START_BUNDLE_MSECS); + new PTransformFunctionRegistry(shortIds, stateTracker, Urns.START_BUNDLE_MSECS); PTransformFunctionRegistry finishFunctionRegistry = - new PTransformFunctionRegistry( - metricsContainerRegistry, - metricsEnvironmentStateForBundle, - shortIds, - stateTracker, - Urns.FINISH_BUNDLE_MSECS); + new PTransformFunctionRegistry(shortIds, stateTracker, Urns.FINISH_BUNDLE_MSECS); List<ThrowingRunnable> resetFunctions = new ArrayList<>(); List<ThrowingRunnable> tearDownFunctions = new ArrayList<>(); @@ -835,7 +824,6 @@ public class ProcessBundleHandler { tearDownFunctions, splitListener, pCollectionConsumerRegistry, - metricsContainerRegistry, metricsEnvironmentStateForBundle, stateTracker, beamFnStateClient, @@ -1028,7 +1016,6 @@ public class ProcessBundleHandler { List<ThrowingRunnable> tearDownFunctions, BundleSplitListener.InMemory splitListener, PCollectionConsumerRegistry pCollectionConsumerRegistry, - MetricsContainerStepMap metricsContainerRegistry, MetricsEnvironmentStateForBundle metricsEnvironmentStateForBundle, ExecutionStateTracker stateTracker, HandleStateCallsForBundle beamFnStateClient, @@ -1044,7 +1031,6 @@ public class ProcessBundleHandler { tearDownFunctions, splitListener, pCollectionConsumerRegistry, - metricsContainerRegistry, metricsEnvironmentStateForBundle, stateTracker, beamFnStateClient, @@ -1081,8 +1067,6 @@ public class ProcessBundleHandler { abstract PCollectionConsumerRegistry getpCollectionConsumerRegistry(); - abstract MetricsContainerStepMap getMetricsContainerRegistry(); - abstract MetricsEnvironmentStateForBundle getMetricsEnvironmentStateForBundle(); public abstract ExecutionStateTracker getStateTracker(); @@ -1140,6 +1124,7 @@ public class ProcessBundleHandler { synchronized void setupForProcessBundleRequest(InstructionRequest request) { this.instructionId = request.getInstructionId(); this.cacheTokens = request.getProcessBundle().getCacheTokensList(); + getMetricsEnvironmentStateForBundle().start(getStateTracker().getMetricsContainer()); } void reset() throws Exception { @@ -1152,7 +1137,6 @@ public class ProcessBundleHandler { } } getSplitListener().clear(); - getMetricsContainerRegistry().reset(); getMetricsEnvironmentStateForBundle().reset(); getStateTracker().reset(); getBundleFinalizationCallbackRegistrations().clear(); @@ -1171,6 +1155,7 @@ public class ProcessBundleHandler { if (this.bundleCache != null) { this.bundleCache.clear(); } + getMetricsEnvironmentStateForBundle().discard(); for (BeamFnDataOutboundAggregator aggregator : getOutboundAggregators().values()) { aggregator.discard(); } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java index 556076ed3b1..45298a68d98 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistry.java @@ -35,7 +35,6 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor; import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo; import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.runners.core.construction.RehydratedComponents; -import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; import org.apache.beam.runners.core.metrics.MonitoringInfoConstants; import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Labels; import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.TypeUrns; @@ -46,8 +45,6 @@ import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.fn.data.FnDataReceiver; import org.apache.beam.sdk.metrics.Distribution; -import org.apache.beam.sdk.metrics.MetricsContainer; -import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder; import org.apache.beam.sdk.util.common.ElementByteSizeObserver; @@ -71,12 +68,9 @@ public class PCollectionConsumerRegistry { @SuppressWarnings({"rawtypes"}) abstract static class ConsumerAndMetadata { public static ConsumerAndMetadata forConsumer( - FnDataReceiver consumer, - String pTransformId, - ExecutionState state, - MetricsContainer metricsContainer) { + FnDataReceiver consumer, String pTransformId, ExecutionState state) { return new AutoValue_PCollectionConsumerRegistry_ConsumerAndMetadata( - consumer, pTransformId, state, metricsContainer); + consumer, pTransformId, state); } public abstract FnDataReceiver getConsumer(); @@ -84,12 +78,8 @@ public class PCollectionConsumerRegistry { public abstract String getPTransformId(); public abstract ExecutionState getExecutionState(); - - public abstract MetricsContainer getMetricsContainer(); } - private final MetricsContainerStepMap metricsContainerRegistry; - private final MetricsEnvironmentState metricsEnvironmentState; private final ExecutionStateTracker stateTracker; private final ShortIdMap shortIdMap; private final Map<String, List<ConsumerAndMetadata>> pCollectionIdsToConsumers; @@ -99,14 +89,10 @@ public class PCollectionConsumerRegistry { private final RehydratedComponents rehydratedComponents; public PCollectionConsumerRegistry( - MetricsContainerStepMap metricsContainerRegistry, - MetricsEnvironmentState metricsEnvironmentState, ExecutionStateTracker stateTracker, ShortIdMap shortIdMap, BundleProgressReporter.Registrar bundleProgressReporterRegistrar, ProcessBundleDescriptor processBundleDescriptor) { - this.metricsContainerRegistry = metricsContainerRegistry; - this.metricsEnvironmentState = metricsEnvironmentState; this.stateTracker = stateTracker; this.shortIdMap = shortIdMap; this.pCollectionIdsToConsumers = new HashMap<>(); @@ -176,11 +162,7 @@ public class PCollectionConsumerRegistry { List<ConsumerAndMetadata> consumerAndMetadatas = pCollectionIdsToConsumers.computeIfAbsent(pCollectionId, (unused) -> new ArrayList<>()); consumerAndMetadatas.add( - ConsumerAndMetadata.forConsumer( - consumer, - pTransformId, - executionState, - metricsContainerRegistry.getContainer(pTransformId))); + ConsumerAndMetadata.forConsumer(consumer, pTransformId, executionState)); } /** @@ -219,18 +201,16 @@ public class PCollectionConsumerRegistry { if (consumerAndMetadatas.size() == 1) { ConsumerAndMetadata consumerAndMetadata = consumerAndMetadatas.get(0); if (consumerAndMetadata.getConsumer() instanceof HandlesSplits) { - return new SplittingMetricTrackingFnDataReceiver( - pcId, coder, consumerAndMetadata, metricsEnvironmentState); + return new SplittingMetricTrackingFnDataReceiver(pcId, coder, consumerAndMetadata); } - return new MetricTrackingFnDataReceiver( - pcId, coder, consumerAndMetadata, metricsEnvironmentState); + return new MetricTrackingFnDataReceiver(pcId, coder, consumerAndMetadata); } else { /* TODO(SDF), Consider supporting splitting each consumer individually. This would never come up in the existing SDF expansion, but might be useful to support fused SDF nodes. This would require dedicated delivery of the split results to each of the consumers separately. */ return new MultiplexingMetricTrackingFnDataReceiver( - pcId, coder, ImmutableList.copyOf(consumerAndMetadatas), metricsEnvironmentState); + pcId, coder, ImmutableList.copyOf(consumerAndMetadatas)); } }); } @@ -248,14 +228,9 @@ public class PCollectionConsumerRegistry { private final BundleCounter elementCountCounter; private final SampleByteSizeDistribution<T> sampledByteSizeDistribution; private final Coder<T> coder; - private final MetricsContainer metricsContainer; - private final MetricsEnvironmentState metricsEnvironmentState; public MetricTrackingFnDataReceiver( - String pCollectionId, - Coder<T> coder, - ConsumerAndMetadata consumerAndMetadata, - MetricsEnvironmentState metricsEnvironmentState) { + String pCollectionId, Coder<T> coder, ConsumerAndMetadata consumerAndMetadata) { this.delegate = consumerAndMetadata.getConsumer(); this.executionState = consumerAndMetadata.getExecutionState(); @@ -291,8 +266,6 @@ public class PCollectionConsumerRegistry { bundleProgressReporterRegistrar.register(sampledByteSizeUnderlyingDistribution); this.coder = coder; - this.metricsContainer = consumerAndMetadata.getMetricsContainer(); - this.metricsEnvironmentState = metricsEnvironmentState; } @Override @@ -303,17 +276,14 @@ public class PCollectionConsumerRegistry { // we have window optimization. this.sampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder); - // Wrap the consumer with extra logic to set the metric container with the appropriate - // PTransform context. This ensures that user metrics obtain the pTransform ID when they are - // created. Also use the ExecutionStateTracker and enter an appropriate state to track the - // Process Bundle Execution time metric. - MetricsContainer oldContainer = metricsEnvironmentState.activate(metricsContainer); + // Use the ExecutionStateTracker and enter an appropriate state to track the + // Process Bundle Execution time metric and also ensure user counters can get an appropriate + // metrics container. executionState.activate(); try { this.delegate.accept(input); } finally { executionState.deactivate(); - metricsEnvironmentState.activate(oldContainer); } this.sampledByteSizeDistribution.finishLazyUpdate(); } @@ -329,18 +299,13 @@ public class PCollectionConsumerRegistry { private class MultiplexingMetricTrackingFnDataReceiver<T> implements FnDataReceiver<WindowedValue<T>> { private final List<ConsumerAndMetadata> consumerAndMetadatas; - private final MetricsEnvironmentState metricsEnvironmentState; private final BundleCounter elementCountCounter; private final SampleByteSizeDistribution<T> sampledByteSizeDistribution; private final Coder<T> coder; public MultiplexingMetricTrackingFnDataReceiver( - String pCollectionId, - Coder<T> coder, - List<ConsumerAndMetadata> consumerAndMetadatas, - MetricsEnvironmentState metricsEnvironmentState) { + String pCollectionId, Coder<T> coder, List<ConsumerAndMetadata> consumerAndMetadatas) { this.consumerAndMetadatas = consumerAndMetadatas; - this.metricsEnvironmentState = metricsEnvironmentState; HashMap<String, String> labels = new HashMap<>(); labels.put(Labels.PCOLLECTION, pCollectionId); @@ -384,20 +349,16 @@ public class PCollectionConsumerRegistry { // when we have window optimization. this.sampledByteSizeDistribution.tryUpdate(input.getValue(), coder); - // Wrap the consumer with extra logic to set the metric container with the appropriate - // PTransform context. This ensures that user metrics obtain the pTransform ID when they are - // created. Also use the ExecutionStateTracker and enter an appropriate state to track the - // Process Bundle Execution time metric. + // Use the ExecutionStateTracker and enter an appropriate state to track the + // Process Bundle Execution time metric and also ensure user counters can get an appropriate + // metrics container. for (ConsumerAndMetadata consumerAndMetadata : consumerAndMetadatas) { - MetricsContainer oldContainer = - metricsEnvironmentState.activate(consumerAndMetadata.getMetricsContainer()); ExecutionState state = consumerAndMetadata.getExecutionState(); state.activate(); try { consumerAndMetadata.getConsumer().accept(input); } finally { state.deactivate(); - metricsEnvironmentState.activate(oldContainer); } this.sampledByteSizeDistribution.finishLazyUpdate(); } @@ -416,11 +377,8 @@ public class PCollectionConsumerRegistry { private final HandlesSplits delegate; public SplittingMetricTrackingFnDataReceiver( - String pCollection, - Coder<T> coder, - ConsumerAndMetadata consumerAndMetadata, - MetricsEnvironmentState metricsEnvironmentState) { - super(pCollection, coder, consumerAndMetadata, metricsEnvironmentState); + String pCollection, Coder<T> coder, ConsumerAndMetadata consumerAndMetadata) { + super(pCollection, coder, consumerAndMetadata); this.delegate = (HandlesSplits) consumerAndMetadata.getConsumer(); } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java index f6bd008c424..ea0a9e76a28 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistry.java @@ -22,23 +22,21 @@ import java.util.List; import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionState; import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker; import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo; -import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; import org.apache.beam.runners.core.metrics.MonitoringInfoConstants; import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns; import org.apache.beam.runners.core.metrics.ShortIdMap; import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder; import org.apache.beam.sdk.function.ThrowingRunnable; -import org.apache.beam.sdk.metrics.MetricsContainer; -import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState; /** - * A class to to register and retrieve functions for bundle processing (i.e. the start, or finish + * A class to register and retrieve functions for bundle processing (i.e. the start, or finish * function). The purpose of this class is to wrap these functions with instrumentation for metrics * and other telemetry collection. * - * <p>Usage: // Instantiate and use the registry for each class of functions. i.e. start. finish. + * <p>Usage: * * <pre> + * // Instantiate and use the registry for each class of functions. i.e. start. finish. * PTransformFunctionRegistry startFunctionRegistry; * PTransformFunctionRegistry finishFunctionRegistry; * startFunctionRegistry.register(myStartThrowingRunnable); @@ -57,8 +55,6 @@ import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState; */ public class PTransformFunctionRegistry { - private final MetricsContainerStepMap metricsContainerRegistry; - private final MetricsEnvironmentState metricsEnvironmentState; private final ExecutionStateTracker stateTracker; private final String executionStateUrn; private final ShortIdMap shortIds; @@ -68,20 +64,12 @@ public class PTransformFunctionRegistry { /** * Construct the registry to run for either start or finish bundle functions. * - * @param metricsContainerRegistry - Used to enable a metric container to properly account for the - * pTransform in user metrics. - * @param metricsEnvironmentState - Used to activate which metrics container receives counter - * updates. * @param shortIds - Provides short ids for {@link MonitoringInfo}. * @param stateTracker - The tracker to enter states in order to calculate execution time metrics. * @param executionStateUrn - The URN for the execution state . */ public PTransformFunctionRegistry( - MetricsContainerStepMap metricsContainerRegistry, - MetricsEnvironmentState metricsEnvironmentState, - ShortIdMap shortIds, - ExecutionStateTracker stateTracker, - String executionStateUrn) { + ShortIdMap shortIds, ExecutionStateTracker stateTracker, String executionStateUrn) { switch (executionStateUrn) { case Urns.START_BUNDLE_MSECS: stateName = org.apache.beam.runners.core.metrics.ExecutionStateTracker.START_STATE_NAME; @@ -92,8 +80,6 @@ public class PTransformFunctionRegistry { default: throw new IllegalArgumentException(String.format("Unknown URN %s", executionStateUrn)); } - this.metricsContainerRegistry = metricsContainerRegistry; - this.metricsEnvironmentState = metricsEnvironmentState; this.shortIds = shortIds; this.executionStateUrn = executionStateUrn; this.stateTracker = stateTracker; @@ -123,17 +109,13 @@ public class PTransformFunctionRegistry { ExecutionState executionState = stateTracker.create(shortId, pTransformId, pTransformUniqueName, stateName); - MetricsContainer container = metricsContainerRegistry.getContainer(pTransformId); - ThrowingRunnable wrapped = () -> { - MetricsContainer oldContainer = metricsEnvironmentState.activate(container); executionState.activate(); try { runnable.run(); } finally { executionState.deactivate(); - metricsEnvironmentState.activate(oldContainer); } }; runnables.add(wrapped); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java index 72a106a75a7..01cdc474e4e 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ExecutionStateSamplerTest.java @@ -33,11 +33,21 @@ import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionState; import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker; import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTrackerStatus; import org.apache.beam.runners.core.metrics.MonitoringInfoEncodings; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.DelegatingHistogram; +import org.apache.beam.sdk.metrics.Distribution; +import org.apache.beam.sdk.metrics.Gauge; +import org.apache.beam.sdk.metrics.Histogram; +import org.apache.beam.sdk.metrics.MetricName; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.testing.ExpectedLogs; +import org.apache.beam.sdk.util.HistogramData; import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString; import org.joda.time.DateTimeUtils.MillisProvider; import org.joda.time.Duration; +import org.junit.After; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -50,8 +60,21 @@ import org.mockito.stubbing.Answer; @RunWith(JUnit4.class) public class ExecutionStateSamplerTest { + private static final Counter TEST_USER_COUNTER = Metrics.counter("foo", "counter"); + private static final Distribution TEST_USER_DISTRIBUTION = + Metrics.distribution("foo", "distribution"); + private static final Gauge TEST_USER_GAUGE = Metrics.gauge("foo", "gauge"); + private static final Histogram TEST_USER_HISTOGRAM = + new DelegatingHistogram( + MetricName.named("foo", "histogram"), HistogramData.LinearBuckets.of(0, 100, 1), false); + @Rule public ExpectedLogs expectedLogs = ExpectedLogs.none(ExecutionStateSampler.class); + @After + public void tearDown() { + MetricsEnvironment.setCurrentContainer(null); + } + @Test public void testSamplingProducesCorrectFinalResults() throws Exception { MillisProvider clock = mock(MillisProvider.class); @@ -335,6 +358,108 @@ public class ExecutionStateSamplerTest { expectedLogs.verifyNotLogged("Operation ongoing"); } + @Test + public void testCountersReturnedAreBasedUponCurrentExecutionState() throws Exception { + MillisProvider clock = mock(MillisProvider.class); + ExecutionStateSampler sampler = + new ExecutionStateSampler( + PipelineOptionsFactory.fromArgs("--experiments=state_sampling_period_millis=10") + .create(), + clock); + ExecutionStateTracker tracker = sampler.create(); + MetricsEnvironment.setCurrentContainer(tracker.getMetricsContainer()); + ExecutionState state = tracker.create("shortId", "ptransformId", "uniqueName", "state"); + + state.activate(); + TEST_USER_COUNTER.inc(); + TEST_USER_DISTRIBUTION.update(2); + TEST_USER_GAUGE.set(3); + TEST_USER_HISTOGRAM.update(4); + state.deactivate(); + + TEST_USER_COUNTER.inc(11); + TEST_USER_DISTRIBUTION.update(12); + TEST_USER_GAUGE.set(13); + TEST_USER_HISTOGRAM.update(14); + TEST_USER_HISTOGRAM.update(14); + + // Verify the execution state was updated + assertEquals( + 1L, + (long) + tracker + .getMetricsContainerRegistry() + .getContainer("ptransformId") + .getCounter(TEST_USER_COUNTER.getName()) + .getCumulative()); + assertEquals( + 2L, + (long) + tracker + .getMetricsContainerRegistry() + .getContainer("ptransformId") + .getDistribution(TEST_USER_DISTRIBUTION.getName()) + .getCumulative() + .sum()); + assertEquals( + 3L, + (long) + tracker + .getMetricsContainerRegistry() + .getContainer("ptransformId") + .getGauge(TEST_USER_GAUGE.getName()) + .getCumulative() + .value()); + assertEquals( + 1L, + (long) + tracker + .getMetricsContainerRegistry() + .getContainer("ptransformId") + .getHistogram( + TEST_USER_HISTOGRAM.getName(), HistogramData.LinearBuckets.of(0, 100, 1)) + .getCumulative() + .getCount(0)); + + // Verify the unbound container + assertEquals( + 11L, + (long) + tracker + .getMetricsContainerRegistry() + .getUnboundContainer() + .getCounter(TEST_USER_COUNTER.getName()) + .getCumulative()); + assertEquals( + 12L, + (long) + tracker + .getMetricsContainerRegistry() + .getUnboundContainer() + .getDistribution(TEST_USER_DISTRIBUTION.getName()) + .getCumulative() + .sum()); + assertEquals( + 13L, + (long) + tracker + .getMetricsContainerRegistry() + .getUnboundContainer() + .getGauge(TEST_USER_GAUGE.getName()) + .getCumulative() + .value()); + assertEquals( + 2L, + (long) + tracker + .getMetricsContainerRegistry() + .getUnboundContainer() + .getHistogram( + TEST_USER_HISTOGRAM.getName(), HistogramData.LinearBuckets.of(0, 100, 1)) + .getCumulative() + .getCount(0)); + } + @Test public void testTrackerReuse() throws Exception { MillisProvider clock = mock(MillisProvider.class); @@ -344,6 +469,7 @@ public class ExecutionStateSamplerTest { .create(), clock); ExecutionStateTracker tracker = sampler.create(); + MetricsEnvironment.setCurrentContainer(tracker.getMetricsContainer()); ExecutionState state = tracker.create("shortId", "ptransformId", "ptransformIdName", "process"); CountDownLatch waitTillActive = new CountDownLatch(1); @@ -384,6 +510,7 @@ public class ExecutionStateSamplerTest { state.activate(); waitTillActive.countDown(); waitForSamples.await(); + TEST_USER_COUNTER.inc(); state.deactivate(); Map<String, ByteString> finalResults = new HashMap<>(); tracker.updateFinalMonitoringData(finalResults); @@ -393,6 +520,14 @@ public class ExecutionStateSamplerTest { // The CountDownLatch ensures that we will see either the prior value or // the latest value. anyOf(equalTo(900L), equalTo(1000L))); + assertEquals( + 1L, + (long) + tracker + .getMetricsContainerRegistry() + .getContainer("ptransformId") + .getCounter(TEST_USER_COUNTER.getName()) + .getCumulative()); tracker.reset(); } @@ -401,6 +536,7 @@ public class ExecutionStateSamplerTest { state.activate(); waitTillSecondStateActive.countDown(); waitForMoreSamples.await(); + TEST_USER_COUNTER.inc(); state.deactivate(); Map<String, ByteString> finalResults = new HashMap<>(); tracker.updateFinalMonitoringData(finalResults); @@ -410,6 +546,14 @@ public class ExecutionStateSamplerTest { // The CountDownLatch ensures that we will see either the prior value or // the latest value. anyOf(equalTo(400L), equalTo(500L))); + assertEquals( + 1L, + (long) + tracker + .getMetricsContainerRegistry() + .getContainer("ptransformId") + .getCounter(TEST_USER_COUNTER.getName()) + .getCumulative()); tracker.reset(); } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java index 404e63b3edf..7df9ed2f894 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/control/ProcessBundleHandlerTest.java @@ -115,7 +115,6 @@ import org.apache.beam.runners.core.construction.ModelCoders; import org.apache.beam.runners.core.construction.PTransformTranslation; import org.apache.beam.runners.core.construction.ParDoTranslation; import org.apache.beam.runners.core.construction.Timer; -import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; import org.apache.beam.runners.core.metrics.ShortIdMap; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -126,6 +125,7 @@ import org.apache.beam.sdk.fn.data.TimerEndpoint; import org.apache.beam.sdk.fn.test.TestExecutors; import org.apache.beam.sdk.fn.test.TestExecutors.TestExecutorService; import org.apache.beam.sdk.function.ThrowingRunnable; +import org.apache.beam.sdk.metrics.MetricsEnvironment; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.state.TimerMap; @@ -290,11 +290,6 @@ public class ProcessBundleHandlerTest { return wrappedBundleProcessor.getpCollectionConsumerRegistry(); } - @Override - MetricsContainerStepMap getMetricsContainerRegistry() { - return wrappedBundleProcessor.getMetricsContainerRegistry(); - } - @Override MetricsEnvironmentStateForBundle getMetricsEnvironmentStateForBundle() { return wrappedBundleProcessor.getMetricsEnvironmentStateForBundle(); @@ -730,7 +725,6 @@ public class ProcessBundleHandlerTest { Collection<CallbackRegistration> bundleFinalizationCallbacks = mock(Collection.class); PCollectionConsumerRegistry pCollectionConsumerRegistry = mock(PCollectionConsumerRegistry.class); - MetricsContainerStepMap metricsContainerRegistry = mock(MetricsContainerStepMap.class); ExecutionStateTracker stateTracker = mock(ExecutionStateTracker.class); ProcessBundleHandler.HandleStateCallsForBundle beamFnStateClient = mock(ProcessBundleHandler.HandleStateCallsForBundle.class); @@ -747,7 +741,6 @@ public class ProcessBundleHandlerTest { new ArrayList<>(), splitListener, pCollectionConsumerRegistry, - metricsContainerRegistry, new MetricsEnvironmentStateForBundle(), stateTracker, beamFnStateClient, @@ -771,10 +764,10 @@ public class ProcessBundleHandlerTest { assertNull(bundleProcessor.getCacheTokens()); assertNull(bundleCache.peek("A")); verify(splitListener, times(1)).clear(); - verify(metricsContainerRegistry, times(1)).reset(); verify(stateTracker, times(1)).reset(); verify(bundleFinalizationCallbacks, times(1)).clear(); verify(resetFunction, times(1)).run(); + assertNull(MetricsEnvironment.getCurrentContainer()); // Ensure that the next setup produces the expected state. bundleProcessor.setupForProcessBundleRequest( diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java index f65237c986e..35bd5697adc 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PCollectionConsumerRegistryTest.java @@ -20,6 +20,7 @@ package org.apache.beam.fn.harness.data; import static org.apache.beam.sdk.util.WindowedValue.valueInGlobalWindow; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; @@ -27,7 +28,6 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import java.util.ArrayList; import java.util.Arrays; @@ -38,12 +38,12 @@ import java.util.Map; import org.apache.beam.fn.harness.HandlesSplits; import org.apache.beam.fn.harness.control.BundleProgressReporter; import org.apache.beam.fn.harness.control.ExecutionStateSampler; +import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker; import org.apache.beam.model.fnexecution.v1.BeamFnApi.ProcessBundleDescriptor; import org.apache.beam.model.pipeline.v1.MetricsApi.MonitoringInfo; import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection; import org.apache.beam.runners.core.construction.SdkComponents; import org.apache.beam.runners.core.metrics.DistributionData; -import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; import org.apache.beam.runners.core.metrics.MonitoringInfoConstants; import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Labels; import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns; @@ -52,9 +52,9 @@ import org.apache.beam.runners.core.metrics.SimpleMonitoringInfoBuilder; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.fn.data.FnDataReceiver; -import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.metrics.MetricsEnvironment; -import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.util.common.ElementByteSizeObservableIterable; @@ -68,8 +68,6 @@ import org.junit.Test; import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.InOrder; -import org.mockito.Mockito; import org.mockito.stubbing.Answer; /** Tests for {@link PCollectionConsumerRegistryTest}. */ @@ -78,6 +76,7 @@ import org.mockito.stubbing.Answer; "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) }) public class PCollectionConsumerRegistryTest { + private static final Counter TEST_USER_COUNTER = Metrics.counter("foo", "bar"); @Rule public ExpectedException expectedException = ExpectedException.none(); @@ -114,6 +113,7 @@ public class PCollectionConsumerRegistryTest { @After public void tearDown() throws Exception { + MetricsEnvironment.setCurrentContainer(null); sampler.stop(); } @@ -121,17 +121,11 @@ public class PCollectionConsumerRegistryTest { public void singleConsumer() throws Exception { final String pTransformIdA = "pTransformIdA"; - MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap(); ShortIdMap shortIds = new ShortIdMap(); BundleProgressReporter.InMemory reporterAndRegistrar = new BundleProgressReporter.InMemory(); PCollectionConsumerRegistry consumers = new PCollectionConsumerRegistry( - metricsContainerRegistry, - MetricsEnvironment::setCurrentContainer, - sampler.create(), - shortIds, - reporterAndRegistrar, - TEST_DESCRIPTOR); + sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR); FnDataReceiver<WindowedValue<String>> consumerA1 = mock(FnDataReceiver.class); consumers.register(P_COLLECTION_A, pTransformIdA, pTransformIdA + "Name", consumerA1); @@ -183,17 +177,11 @@ public class PCollectionConsumerRegistryTest { final String pTransformId = "pTransformId"; final String message = "testException"; - MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap(); ShortIdMap shortIds = new ShortIdMap(); BundleProgressReporter.InMemory reporterAndRegistrar = new BundleProgressReporter.InMemory(); PCollectionConsumerRegistry consumers = new PCollectionConsumerRegistry( - metricsContainerRegistry, - MetricsEnvironment::setCurrentContainer, - sampler.create(), - shortIds, - reporterAndRegistrar, - TEST_DESCRIPTOR); + sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR); FnDataReceiver<WindowedValue<String>> consumer = mock(FnDataReceiver.class); consumers.register(P_COLLECTION_A, pTransformId, pTransformId + "Name", consumer); @@ -211,17 +199,11 @@ public class PCollectionConsumerRegistryTest { /** Test that the counter increments even when there are no consumers of the PCollection. */ @Test public void noConsumers() throws Exception { - MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap(); ShortIdMap shortIds = new ShortIdMap(); BundleProgressReporter.InMemory reporterAndRegistrar = new BundleProgressReporter.InMemory(); PCollectionConsumerRegistry consumers = new PCollectionConsumerRegistry( - metricsContainerRegistry, - MetricsEnvironment::setCurrentContainer, - sampler.create(), - shortIds, - reporterAndRegistrar, - TEST_DESCRIPTOR); + sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR); FnDataReceiver<WindowedValue<String>> wrapperConsumer = (FnDataReceiver<WindowedValue<String>>) @@ -271,17 +253,11 @@ public class PCollectionConsumerRegistryTest { final String pTransformIdA = "pTransformIdA"; final String pTransformIdB = "pTransformIdB"; - MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap(); ShortIdMap shortIds = new ShortIdMap(); BundleProgressReporter.InMemory reporterAndRegistrar = new BundleProgressReporter.InMemory(); PCollectionConsumerRegistry consumers = new PCollectionConsumerRegistry( - metricsContainerRegistry, - MetricsEnvironment::setCurrentContainer, - sampler.create(), - shortIds, - reporterAndRegistrar, - TEST_DESCRIPTOR); + sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR); FnDataReceiver<WindowedValue<String>> consumerA1 = mock(FnDataReceiver.class); FnDataReceiver<WindowedValue<String>> consumerA2 = mock(FnDataReceiver.class); @@ -336,17 +312,11 @@ public class PCollectionConsumerRegistryTest { final String pTransformId = "pTransformId"; final String message = "testException"; - MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap(); ShortIdMap shortIds = new ShortIdMap(); BundleProgressReporter.InMemory reporterAndRegistrar = new BundleProgressReporter.InMemory(); PCollectionConsumerRegistry consumers = new PCollectionConsumerRegistry( - metricsContainerRegistry, - MetricsEnvironment::setCurrentContainer, - sampler.create(), - shortIds, - reporterAndRegistrar, - TEST_DESCRIPTOR); + sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR); FnDataReceiver<WindowedValue<String>> consumerA1 = mock(FnDataReceiver.class); FnDataReceiver<WindowedValue<String>> consumerA2 = mock(FnDataReceiver.class); @@ -367,17 +337,11 @@ public class PCollectionConsumerRegistryTest { public void throwsOnRegisteringAfterMultiplexingConsumerWasInitialized() throws Exception { final String pTransformId = "pTransformId"; - MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap(); ShortIdMap shortIds = new ShortIdMap(); BundleProgressReporter.InMemory reporterAndRegistrar = new BundleProgressReporter.InMemory(); PCollectionConsumerRegistry consumers = new PCollectionConsumerRegistry( - metricsContainerRegistry, - MetricsEnvironment::setCurrentContainer, - sampler.create(), - shortIds, - reporterAndRegistrar, - TEST_DESCRIPTOR); + sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR); FnDataReceiver<WindowedValue<String>> consumerA1 = mock(FnDataReceiver.class); FnDataReceiver<WindowedValue<String>> consumerA2 = mock(FnDataReceiver.class); @@ -391,31 +355,19 @@ public class PCollectionConsumerRegistryTest { @Test public void testMetricContainerUpdatedUponAcceptingElement() throws Exception { - MetricsEnvironmentState metricsEnvironmentState = mock(MetricsEnvironmentState.class); - - MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap(); + ExecutionStateTracker executionStateTracker = sampler.create(); + MetricsEnvironment.setCurrentContainer(executionStateTracker.getMetricsContainer()); ShortIdMap shortIds = new ShortIdMap(); BundleProgressReporter.InMemory reporterAndRegistrar = new BundleProgressReporter.InMemory(); + executionStateTracker.start("testBundle"); PCollectionConsumerRegistry consumers = new PCollectionConsumerRegistry( - metricsContainerRegistry, - metricsEnvironmentState, - sampler.create(), - shortIds, - reporterAndRegistrar, - TEST_DESCRIPTOR); - FnDataReceiver<WindowedValue<String>> consumerA1 = mock(FnDataReceiver.class); - FnDataReceiver<WindowedValue<String>> consumerA2 = mock(FnDataReceiver.class); + executionStateTracker, shortIds, reporterAndRegistrar, TEST_DESCRIPTOR); - consumers.register(P_COLLECTION_A, "pTransformA", "pTransformAName", consumerA1); - consumers.register(P_COLLECTION_A, "pTransformB", "pTransformBName", consumerA2); - - // Test both cases; when there is an existing container and where there is no container - MetricsContainer oldContainer = mock(MetricsContainer.class); - when(metricsEnvironmentState.activate(metricsContainerRegistry.getContainer("pTransformA"))) - .thenReturn(oldContainer); - when(metricsEnvironmentState.activate(metricsContainerRegistry.getContainer("pTransformB"))) - .thenReturn(null); + consumers.register( + P_COLLECTION_A, "pTransformA", "pTransformAName", (unused) -> TEST_USER_COUNTER.inc()); + consumers.register( + P_COLLECTION_A, "pTransformB", "pTransformBName", (unused) -> TEST_USER_COUNTER.inc(2)); FnDataReceiver<WindowedValue<String>> wrapperConsumer = (FnDataReceiver<WindowedValue<String>>) @@ -423,36 +375,45 @@ public class PCollectionConsumerRegistryTest { WindowedValue<String> element = valueInGlobalWindow("elem"); wrapperConsumer.accept(element); - - // Verify that metrics environment state is updated with pTransformA's container, then reset to - // the oldContainer, then pTransformB's container and then reset to null. - InOrder inOrder = Mockito.inOrder(metricsEnvironmentState); - inOrder - .verify(metricsEnvironmentState) - .activate(metricsContainerRegistry.getContainer("pTransformA")); - inOrder.verify(metricsEnvironmentState).activate(oldContainer); - inOrder - .verify(metricsEnvironmentState) - .activate(metricsContainerRegistry.getContainer("pTransformB")); - inOrder.verify(metricsEnvironmentState).activate(null); - inOrder.verifyNoMoreInteractions(); + TEST_USER_COUNTER.inc(3); + + // Verify that metrics environment state is updated with pTransform's counters including the + // unbound container when outside the scope of the function + assertEquals( + 1L, + (long) + executionStateTracker + .getMetricsContainerRegistry() + .getContainer("pTransformA") + .getCounter(TEST_USER_COUNTER.getName()) + .getCumulative()); + assertEquals( + 2L, + (long) + executionStateTracker + .getMetricsContainerRegistry() + .getContainer("pTransformB") + .getCounter(TEST_USER_COUNTER.getName()) + .getCumulative()); + assertEquals( + 3L, + (long) + executionStateTracker + .getMetricsContainerRegistry() + .getUnboundContainer() + .getCounter(TEST_USER_COUNTER.getName()) + .getCumulative()); } @Test public void testHandlesSplitsPassedToOriginalConsumer() throws Exception { final String pTransformIdA = "pTransformIdA"; - MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap(); ShortIdMap shortIds = new ShortIdMap(); BundleProgressReporter.InMemory reporterAndRegistrar = new BundleProgressReporter.InMemory(); PCollectionConsumerRegistry consumers = new PCollectionConsumerRegistry( - metricsContainerRegistry, - MetricsEnvironment::setCurrentContainer, - sampler.create(), - shortIds, - reporterAndRegistrar, - TEST_DESCRIPTOR); + sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR); SplittingReceiver consumerA1 = mock(SplittingReceiver.class); consumers.register(P_COLLECTION_A, pTransformIdA, pTransformIdA + "Name", consumerA1); @@ -474,17 +435,11 @@ public class PCollectionConsumerRegistryTest { public void testLazyByteSizeEstimation() throws Exception { final String pTransformIdA = "pTransformIdA"; - MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap(); ShortIdMap shortIds = new ShortIdMap(); BundleProgressReporter.InMemory reporterAndRegistrar = new BundleProgressReporter.InMemory(); PCollectionConsumerRegistry consumers = new PCollectionConsumerRegistry( - metricsContainerRegistry, - MetricsEnvironment::setCurrentContainer, - sampler.create(), - shortIds, - reporterAndRegistrar, - TEST_DESCRIPTOR); + sampler.create(), shortIds, reporterAndRegistrar, TEST_DESCRIPTOR); FnDataReceiver<WindowedValue<Iterable<String>>> consumerA1 = mock(FnDataReceiver.class); consumers.register(P_COLLECTION_B, pTransformIdA, pTransformIdA + "Name", consumerA1); diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java index 7def4286258..6e06c16c653 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/data/PTransformFunctionRegistryTest.java @@ -20,32 +20,28 @@ package org.apache.beam.fn.harness.data; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.fn.harness.control.ExecutionStateSampler; import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTracker; import org.apache.beam.fn.harness.control.ExecutionStateSampler.ExecutionStateTrackerStatus; -import org.apache.beam.runners.core.metrics.MetricsContainerStepMap; import org.apache.beam.runners.core.metrics.MonitoringInfoConstants.Urns; import org.apache.beam.runners.core.metrics.ShortIdMap; import org.apache.beam.sdk.function.ThrowingRunnable; -import org.apache.beam.sdk.metrics.MetricsContainer; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.metrics.MetricsEnvironment; -import org.apache.beam.sdk.metrics.MetricsEnvironment.MetricsEnvironmentState; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -import org.mockito.InOrder; -import org.mockito.Mockito; /** Tests for {@link PTransformFunctionRegistry}. */ @RunWith(JUnit4.class) public class PTransformFunctionRegistryTest { + private static final Counter TEST_USER_COUNTER = Metrics.counter("foo", "bar"); private ExecutionStateSampler sampler; @@ -56,19 +52,17 @@ public class PTransformFunctionRegistryTest { @After public void tearDown() { + MetricsEnvironment.setCurrentContainer(null); sampler.stop(); } @Test public void testStateTrackerRecordsStateTransitions() throws Exception { ExecutionStateTracker executionStateTracker = sampler.create(); + MetricsEnvironment.setCurrentContainer(executionStateTracker.getMetricsContainer()); PTransformFunctionRegistry testObject = new PTransformFunctionRegistry( - mock(MetricsContainerStepMap.class), - MetricsEnvironment::setCurrentContainer, - new ShortIdMap(), - executionStateTracker, - Urns.START_BUNDLE_MSECS); + new ShortIdMap(), executionStateTracker, Urns.START_BUNDLE_MSECS); final AtomicBoolean runnableAWasCalled = new AtomicBoolean(); final AtomicBoolean runnableBWasCalled = new AtomicBoolean(); @@ -111,46 +105,49 @@ public class PTransformFunctionRegistryTest { @Test public void testMetricsUponRunningFunctions() throws Exception { - MetricsEnvironmentState metricsEnvironmentState = mock(MetricsEnvironmentState.class); ExecutionStateTracker executionStateTracker = sampler.create(); - MetricsContainerStepMap metricsContainerRegistry = new MetricsContainerStepMap(); + MetricsEnvironment.setCurrentContainer(executionStateTracker.getMetricsContainer()); PTransformFunctionRegistry testObject = new PTransformFunctionRegistry( - metricsContainerRegistry, - metricsEnvironmentState, - new ShortIdMap(), - executionStateTracker, - Urns.START_BUNDLE_MSECS); + new ShortIdMap(), executionStateTracker, Urns.START_BUNDLE_MSECS); - ThrowingRunnable runnableA = mock(ThrowingRunnable.class); - ThrowingRunnable runnableB = mock(ThrowingRunnable.class); - testObject.register("pTransformA", "pTranformAName", runnableA); - testObject.register("pTransformB", "pTranformBName", runnableB); + testObject.register("pTransformA", "pTranformAName", () -> TEST_USER_COUNTER.inc()); + testObject.register("pTransformB", "pTranformBName", () -> TEST_USER_COUNTER.inc(2)); // Test both cases; when there is an existing container and where there is no container - MetricsContainer oldContainer = mock(MetricsContainer.class); - when(metricsEnvironmentState.activate(metricsContainerRegistry.getContainer("pTransformA"))) - .thenReturn(oldContainer); - when(metricsEnvironmentState.activate(metricsContainerRegistry.getContainer("pTransformB"))) - .thenReturn(null); - executionStateTracker.start("testBundleId"); for (ThrowingRunnable func : testObject.getFunctions()) { func.run(); } - executionStateTracker.reset(); + TEST_USER_COUNTER.inc(3); + + // Verify that metrics environment state is updated with pTransform's counters including the + // unbound container when outside the scope of the function + assertEquals( + 1L, + (long) + executionStateTracker + .getMetricsContainerRegistry() + .getContainer("pTransformA") + .getCounter(TEST_USER_COUNTER.getName()) + .getCumulative()); + assertEquals( + 2L, + (long) + executionStateTracker + .getMetricsContainerRegistry() + .getContainer("pTransformB") + .getCounter(TEST_USER_COUNTER.getName()) + .getCumulative()); + assertEquals( + 3L, + (long) + executionStateTracker + .getMetricsContainerRegistry() + .getUnboundContainer() + .getCounter(TEST_USER_COUNTER.getName()) + .getCumulative()); - // Verify that metrics environment state is updated with pTransformA's container, then reset to - // the oldContainer, then pTransformB's container and then reset to null. - InOrder inOrder = Mockito.inOrder(metricsEnvironmentState); - inOrder - .verify(metricsEnvironmentState) - .activate(metricsContainerRegistry.getContainer("pTransformA")); - inOrder.verify(metricsEnvironmentState).activate(oldContainer); - inOrder - .verify(metricsEnvironmentState) - .activate(metricsContainerRegistry.getContainer("pTransformB")); - inOrder.verify(metricsEnvironmentState).activate(null); - inOrder.verifyNoMoreInteractions(); + executionStateTracker.reset(); } }