Repository: incubator-beam Updated Branches: refs/heads/master 86d222aab -> 0a413e78e
Perform initial splitting in the DirectRunner. This allows sources to be read from in parallel and generates initial splits. Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/f68fea02 Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/f68fea02 Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/f68fea02 Branch: refs/heads/master Commit: f68fea02b63e5844b9ccbd31ff8e02da407f65b7 Parents: 86d222a Author: Thomas Groh <tg...@google.com> Authored: Wed Oct 5 16:11:21 2016 -0700 Committer: Luke Cwik <lc...@google.com> Committed: Fri Oct 14 13:54:55 2016 -0700 ---------------------------------------------------------------------- .../direct/BoundedReadEvaluatorFactory.java | 40 ++- .../beam/runners/direct/DirectOptions.java | 23 ++ .../beam/runners/direct/DirectRunner.java | 11 +- .../beam/runners/direct/EmptyInputProvider.java | 12 +- .../direct/ExecutorServiceParallelExecutor.java | 15 +- .../beam/runners/direct/RootInputProvider.java | 7 +- .../runners/direct/RootProviderRegistry.java | 5 +- .../direct/TestStreamEvaluatorFactory.java | 4 +- .../direct/TransformEvaluatorRegistry.java | 10 +- .../direct/UnboundedReadEvaluatorFactory.java | 35 ++- .../beam/runners/direct/WatermarkManager.java | 1 + .../direct/BoundedReadEvaluatorFactoryTest.java | 41 ++- .../direct/FlattenEvaluatorFactoryTest.java | 9 +- .../direct/TestStreamEvaluatorFactoryTest.java | 2 +- .../UnboundedReadEvaluatorFactoryTest.java | 55 +++- .../sdk/io/gcp/bigquery/BigQueryAvroUtils.java | 69 ++++- .../io/gcp/bigquery/BigQueryAvroUtilsTest.java | 132 +++++++-- .../sdk/io/gcp/bigquery/BigQueryIOTest.java | 292 ++++++++++++++++++- .../sdk/io/gcp/bigtable/BigtableIOTest.java | 9 +- 19 files changed, 662 insertions(+), 110 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java index 326a535..843dcd6 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactory.java @@ -18,28 +18,32 @@ package org.apache.beam.runners.direct; import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; import java.io.IOException; import java.util.Collection; -import java.util.Collections; +import java.util.List; import javax.annotation.Nullable; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; import org.apache.beam.runners.direct.StepTransformResult.Builder; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.BoundedSource.BoundedReader; -import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.io.Read.Bounded; +import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A {@link TransformEvaluatorFactory} that produces {@link TransformEvaluator TransformEvaluators} * for the {@link Bounded Read.Bounded} primitive {@link PTransform}. 
*/ final class BoundedReadEvaluatorFactory implements TransformEvaluatorFactory { + private static final Logger LOG = LoggerFactory.getLogger(BoundedReadEvaluatorFactory.class); private final EvaluationContext evaluationContext; BoundedReadEvaluatorFactory(EvaluationContext evaluationContext) { @@ -126,18 +130,32 @@ final class BoundedReadEvaluatorFactory implements TransformEvaluatorFactory { } @Override - public Collection<CommittedBundle<?>> getInitialInputs(AppliedPTransform<?, ?, ?> transform) { - return createInitialSplits((AppliedPTransform) transform); + public Collection<CommittedBundle<?>> getInitialInputs( + AppliedPTransform<?, ?, ?> transform, int targetParallelism) throws Exception { + return createInitialSplits((AppliedPTransform) transform, targetParallelism); } - private <OutputT> Collection<CommittedBundle<?>> createInitialSplits( - AppliedPTransform<?, ?, Read.Bounded<OutputT>> transform) { + private <OutputT> + Collection<CommittedBundle<BoundedSourceShard<OutputT>>> createInitialSplits( + AppliedPTransform<?, ?, Bounded<OutputT>> transform, int targetParallelism) + throws Exception { BoundedSource<OutputT> source = transform.getTransform().getSource(); - return Collections.<CommittedBundle<?>>singleton( - evaluationContext - .<BoundedSourceShard<OutputT>>createRootBundle() - .add(WindowedValue.valueInGlobalWindow(BoundedSourceShard.of(source))) - .commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); + PipelineOptions options = evaluationContext.getPipelineOptions(); + long estimatedBytes = source.getEstimatedSizeBytes(options); + long bytesPerBundle = estimatedBytes / targetParallelism; + List<? 
extends BoundedSource<OutputT>> bundles = + source.splitIntoBundles(bytesPerBundle, options); + ImmutableList.Builder<CommittedBundle<BoundedSourceShard<OutputT>>> shards = + ImmutableList.builder(); + for (BoundedSource<OutputT> bundle : bundles) { + CommittedBundle<BoundedSourceShard<OutputT>> inputShard = + evaluationContext + .<BoundedSourceShard<OutputT>>createRootBundle() + .add(WindowedValue.valueInGlobalWindow(BoundedSourceShard.of(bundle))) + .commit(BoundedWindow.TIMESTAMP_MAX_VALUE); + shards.add(inputShard); + } + return shards.build(); } } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java index 89e1bb8..b2c4f47 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectOptions.java @@ -19,6 +19,7 @@ package org.apache.beam.runners.direct; import org.apache.beam.sdk.options.ApplicationNameOptions; import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.DefaultValueFactory; import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.PipelineOptions; @@ -62,4 +63,26 @@ public interface DirectOptions extends PipelineOptions, ApplicationNameOptions { + "PCollection are encodable. All elements in a PCollection must be encodable.") boolean isEnforceEncodability(); void setEnforceEncodability(boolean test); + + @Default.InstanceFactory(AvailableParallelismFactory.class) + @Description( + "Controls the amount of target parallelism the DirectRunner will use. Defaults to" + + " the greater of the number of available processors and 3. 
Must be a value greater" + + " than zero.") + int getTargetParallelism(); + void setTargetParallelism(int target); + + /** + * A {@link DefaultValueFactory} that returns the result of {@link Runtime#availableProcessors()} + * from the {@link #create(PipelineOptions)} method. Uses {@link Runtime#getRuntime()} to obtain + * the {@link Runtime}. + */ + class AvailableParallelismFactory implements DefaultValueFactory<Integer> { + private static final int MIN_PARALLELISM = 3; + + @Override + public Integer create(PipelineOptions options) { + return Math.max(Runtime.getRuntime().availableProcessors(), MIN_PARALLELISM); + } + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java index 8941093..6ef2472 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java @@ -186,7 +186,7 @@ public class DirectRunner //////////////////////////////////////////////////////////////////////////////////////////////// private final DirectOptions options; - private Supplier<ExecutorService> executorServiceSupplier = new FixedThreadPoolSupplier(); + private Supplier<ExecutorService> executorServiceSupplier; private Supplier<Clock> clockSupplier = new NanosOffsetClockSupplier(); public static DirectRunner fromOptions(PipelineOptions options) { @@ -195,6 +195,7 @@ public class DirectRunner private DirectRunner(DirectOptions options) { this.options = options; + this.executorServiceSupplier = new FixedThreadPoolSupplier(options); } /** @@ -440,9 +441,15 @@ public class DirectRunner * {@link Executors#newFixedThreadPool(int)}. 
*/ private static class FixedThreadPoolSupplier implements Supplier<ExecutorService> { + private final DirectOptions options; + + private FixedThreadPoolSupplier(DirectOptions options) { + this.options = options; + } + @Override public ExecutorService get() { - return Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors()); + return Executors.newFixedThreadPool(options.getTargetParallelism()); } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java index 10d63e9..1058943 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/EmptyInputProvider.java @@ -21,8 +21,6 @@ import java.util.Collection; import java.util.Collections; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.sdk.transforms.AppliedPTransform; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; /** * A {@link RootInputProvider} that provides a singleton empty bundle. @@ -37,13 +35,11 @@ class EmptyInputProvider implements RootInputProvider { /** * {@inheritDoc}. * - * <p>Returns a single empty bundle. This bundle ensures that any {@link PTransform PTransforms} - * that consume from the output of the provided {@link AppliedPTransform} have watermarks updated - * as appropriate. + * <p>Returns an empty collection. 
*/ @Override - public Collection<CommittedBundle<?>> getInitialInputs(AppliedPTransform<?, ?, ?> transform) { - return Collections.<CommittedBundle<?>>singleton( - evaluationContext.createRootBundle().commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); + public Collection<CommittedBundle<?>> getInitialInputs( + AppliedPTransform<?, ?, ?> transform, int targetParallelism) { + return Collections.emptyList(); } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java index 3761574..3274524 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ExecutorServiceParallelExecutor.java @@ -17,8 +17,6 @@ */ package org.apache.beam.runners.direct; -import static com.google.common.base.Preconditions.checkState; - import com.google.auto.value.AutoValue; import com.google.common.base.MoreObjects; import com.google.common.base.Optional; @@ -51,6 +49,7 @@ import org.apache.beam.sdk.util.KeyedWorkItem; import org.apache.beam.sdk.util.KeyedWorkItems; import org.apache.beam.sdk.util.TimeDomain; import org.apache.beam.sdk.util.TimerInternals.TimerData; +import org.apache.beam.sdk.util.UserCodeException; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; @@ -166,12 +165,12 @@ final class ExecutorServiceParallelExecutor implements PipelineExecutor { public void start(Collection<AppliedPTransform<?, ?, ?>> roots) { for (AppliedPTransform<?, ?, ?> root : roots) { 
ConcurrentLinkedQueue<CommittedBundle<?>> pending = new ConcurrentLinkedQueue<>(); - Collection<CommittedBundle<?>> initialInputs = rootInputProvider.getInitialInputs(root); - checkState( - !initialInputs.isEmpty(), - "All root transforms must have initial inputs. Got 0 for %s", - root.getFullName()); - pending.addAll(initialInputs); + try { + Collection<CommittedBundle<?>> initialInputs = rootInputProvider.getInitialInputs(root, 1); + pending.addAll(initialInputs); + } catch (Exception e) { + throw UserCodeException.wrap(e); + } pendingRootBundles.put(root, pending); } evaluationContext.initialize(pendingRootBundles); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java index 40c7301..19d0040 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootInputProvider.java @@ -36,6 +36,11 @@ interface RootInputProvider { * <p>For source transforms, these should be sufficient that, when provided to the evaluators * produced by {@link TransformEvaluatorFactory#forApplication(AppliedPTransform, * CommittedBundle)}, all of the elements contained in the source are eventually produced. + * + * @param transform the {@link AppliedPTransform} to get initial inputs for. + * @param targetParallelism the target amount of parallelism to obtain from the source. Must be + * greater than or equal to 1. 
*/ - Collection<CommittedBundle<?>> getInitialInputs(AppliedPTransform<?, ?, ?> transform); + Collection<CommittedBundle<?>> getInitialInputs( + AppliedPTransform<?, ?, ?> transform, int targetParallelism) throws Exception; } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java index f6335fd..bb5fcd2 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/RootProviderRegistry.java @@ -52,7 +52,8 @@ class RootProviderRegistry implements RootInputProvider { } @Override - public Collection<CommittedBundle<?>> getInitialInputs(AppliedPTransform<?, ?, ?> transform) { + public Collection<CommittedBundle<?>> getInitialInputs( + AppliedPTransform<?, ?, ?> transform, int targetParallelism) throws Exception { Class<? 
extends PTransform> transformClass = transform.getTransform().getClass(); RootInputProvider provider = checkNotNull( @@ -60,6 +61,6 @@ class RootProviderRegistry implements RootInputProvider { "Tried to get a %s for a Transform of type %s, but there is no such provider", RootInputProvider.class.getSimpleName(), transformClass.getSimpleName()); - return provider.getInitialInputs(transform); + return provider.getInitialInputs(transform, targetParallelism); } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java index 4a48a58..065adc1 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactory.java @@ -199,7 +199,8 @@ class TestStreamEvaluatorFactory implements TransformEvaluatorFactory { } @Override - public Collection<CommittedBundle<?>> getInitialInputs(AppliedPTransform<?, ?, ?> transform) { + public Collection<CommittedBundle<?>> getInitialInputs( + AppliedPTransform<?, ?, ?> transform, int targetParallelism) { return createInputBundle((AppliedPTransform) transform); } @@ -213,6 +214,7 @@ class TestStreamEvaluatorFactory implements TransformEvaluatorFactory { return Collections.<CommittedBundle<?>>singleton(initialBundle); } } + @AutoValue abstract static class TestStreamIndex<T> { static <T> TestStreamIndex<T> of(TestStream<T> stream) { http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java 
---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java index 4b495e6..3dd44a7 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/TransformEvaluatorRegistry.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.direct; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import com.google.common.collect.ImmutableMap; @@ -81,14 +82,13 @@ class TransformEvaluatorRegistry implements TransformEvaluatorFactory { throws Exception { checkState( !finished.get(), "Tried to get an evaluator for a finished TransformEvaluatorRegistry"); - TransformEvaluatorFactory factory = getFactory(application); + Class<? 
extends PTransform> transformClass = application.getTransform().getClass(); + TransformEvaluatorFactory factory = + checkNotNull( + factories.get(transformClass), "No evaluator for PTransform type %s", transformClass); return factory.forApplication(application, inputBundle); } - private TransformEvaluatorFactory getFactory(AppliedPTransform<?, ?, ?> application) { - return factories.get(application.getTransform().getClass()); - } - @Override public void cleanup() throws Exception { Collection<Exception> thrownInCleanup = new ArrayList<>(); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java index 08dc286..18d3d0a 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactory.java @@ -19,9 +19,11 @@ package org.apache.beam.runners.direct; import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import java.io.IOException; import java.util.Collection; import java.util.Collections; +import java.util.List; import java.util.concurrent.ThreadLocalRandom; import javax.annotation.Nullable; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; @@ -33,20 +35,22 @@ import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; import org.apache.beam.sdk.io.UnboundedSource.UnboundedReader; import org.apache.beam.sdk.transforms.AppliedPTransform; -import 
org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A {@link TransformEvaluatorFactory} that produces {@link TransformEvaluator TransformEvaluators} * for the {@link Unbounded Read.Unbounded} primitive {@link PTransform}. */ class UnboundedReadEvaluatorFactory implements TransformEvaluatorFactory { + private static final Logger LOG = LoggerFactory.getLogger(UnboundedReadEvaluatorFactory.class); // Occasionally close an existing reader and resume from checkpoint, to exercise close-and-resume - @VisibleForTesting static final double DEFAULT_READER_REUSE_CHANCE = 0.95; + private static final double DEFAULT_READER_REUSE_CHANCE = 0.95; private final EvaluationContext evaluationContext; private final double readerReuseChance; @@ -253,24 +257,33 @@ class UnboundedReadEvaluatorFactory implements TransformEvaluatorFactory { } @Override - public Collection<CommittedBundle<?>> getInitialInputs(AppliedPTransform<?, ?, ?> transform) { - return createInitialSplits((AppliedPTransform) transform); + public Collection<CommittedBundle<?>> getInitialInputs( + AppliedPTransform<?, ?, ?> transform, int targetParallelism) throws Exception { + return createInitialSplits((AppliedPTransform) transform, targetParallelism); } private <OutputT> Collection<CommittedBundle<?>> createInitialSplits( - AppliedPTransform<?, ?, Read.Unbounded<OutputT>> transform) { + AppliedPTransform<?, ?, Unbounded<OutputT>> transform, int targetParallelism) + throws Exception { UnboundedSource<OutputT, ?> source = transform.getTransform().getSource(); + List<? 
extends UnboundedSource<OutputT, ?>> splits = + source.generateInitialSplits(targetParallelism, evaluationContext.getPipelineOptions()); UnboundedReadDeduplicator deduplicator = source.requiresDeduping() ? UnboundedReadDeduplicator.CachedIdDeduplicator.create() : NeverDeduplicator.create(); - UnboundedSourceShard<OutputT, ?> shard = UnboundedSourceShard.unstarted(source, deduplicator); - return Collections.<CommittedBundle<?>>singleton( - evaluationContext - .<UnboundedSourceShard<?, ?>>createRootBundle() - .add(WindowedValue.<UnboundedSourceShard<?, ?>>valueInGlobalWindow(shard)) - .commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); + ImmutableList.Builder<CommittedBundle<?>> initialShards = ImmutableList.builder(); + for (UnboundedSource<OutputT, ?> split : splits) { + UnboundedSourceShard<OutputT, ?> shard = + UnboundedSourceShard.unstarted(split, deduplicator); + initialShards.add( + evaluationContext + .<UnboundedSourceShard<?, ?>>createRootBundle() + .add(WindowedValue.<UnboundedSourceShard<?, ?>>valueInGlobalWindow(shard)) + .commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); + } + return initialShards.build(); } } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java index 82a6e4f..c55a036 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WatermarkManager.java @@ -799,6 +799,7 @@ public class WatermarkManager { for (CommittedBundle<?> initialBundle : rootEntry.getValue()) { rootWms.addPending(initialBundle); } + pendingRefreshes.offer(rootEntry.getKey()); } } 
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java index ee17eae..8a76a53 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/BoundedReadEvaluatorFactoryTest.java @@ -17,10 +17,14 @@ */ package org.apache.beam.runners.direct; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.when; @@ -43,9 +47,11 @@ import org.apache.beam.sdk.io.CountingSource; import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.SourceTestUtils; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; import org.hamcrest.Matchers; @@ -56,6 +62,8 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; import 
org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; /** * Tests for {@link BoundedReadEvaluatorFactory}. @@ -87,7 +95,7 @@ public class BoundedReadEvaluatorFactoryTest { Collection<CommittedBundle<?>> initialInputs = new BoundedReadEvaluatorFactory.InputProvider(context) - .getInitialInputs(longs.getProducingTransformInternal()); + .getInitialInputs(longs.getProducingTransformInternal(), 1); List<WindowedValue<?>> outputs = new ArrayList<>(); for (CommittedBundle<?> shardBundle : initialInputs) { TransformEvaluator<?> evaluator = @@ -115,6 +123,37 @@ public class BoundedReadEvaluatorFactoryTest { } @Test + public void getInitialInputsSplitsIntoBundles() throws Exception { + when(context.createRootBundle()) + .thenAnswer( + new Answer<UncommittedBundle<?>>() { + @Override + public UncommittedBundle<?> answer(InvocationOnMock invocation) throws Throwable { + return bundleFactory.createRootBundle(); + } + }); + Collection<CommittedBundle<?>> initialInputs = + new BoundedReadEvaluatorFactory.InputProvider(context) + .getInitialInputs(longs.getProducingTransformInternal(), 3); + + assertThat(initialInputs, hasSize(allOf(greaterThanOrEqualTo(3), lessThanOrEqualTo(4)))); + + Collection<BoundedSource<Long>> sources = new ArrayList<>(); + for (CommittedBundle<?> initialInput : initialInputs) { + Iterable<WindowedValue<BoundedSourceShard<Long>>> shards = + (Iterable) initialInput.getElements(); + WindowedValue<BoundedSourceShard<Long>> shard = Iterables.getOnlyElement(shards); + assertThat(shard.getWindows(), Matchers.<BoundedWindow>contains(GlobalWindow.INSTANCE)); + assertThat(shard.getTimestamp(), equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE)); + sources.add(shard.getValue().getSource()); + } + + SourceTestUtils.assertSourcesEqualReferenceSource(source, + (List<? 
extends BoundedSource<Long>>) sources, + PipelineOptionsFactory.create()); + } + + @Test public void boundedSourceInMemoryTransformEvaluatorShardsOfSource() throws Exception { PipelineOptions options = PipelineOptionsFactory.create(); List<? extends BoundedSource<Long>> splits = http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java index aa7b178..417aa64 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/FlattenEvaluatorFactoryTest.java @@ -24,13 +24,13 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import com.google.common.collect.Iterables; -import java.util.Collection; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.transforms.Flatten; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.PCollection; @@ -122,17 +122,14 @@ public class FlattenEvaluatorFactoryTest { PCollection<Integer> flattened = list.apply(Flatten.<Integer>pCollections()); EvaluationContext evaluationContext = mock(EvaluationContext.class); - when(evaluationContext.createRootBundle()).thenReturn(bundleFactory.createRootBundle()); 
when(evaluationContext.createBundle(flattened)) .thenReturn(bundleFactory.createBundle(flattened)); FlattenEvaluatorFactory factory = new FlattenEvaluatorFactory(evaluationContext); - Collection<CommittedBundle<?>> initialInputs = - new EmptyInputProvider(evaluationContext) - .getInitialInputs(flattened.getProducingTransformInternal()); TransformEvaluator<Integer> emptyEvaluator = factory.forApplication( - flattened.getProducingTransformInternal(), Iterables.getOnlyElement(initialInputs)); + flattened.getProducingTransformInternal(), + bundleFactory.createRootBundle().commit(BoundedWindow.TIMESTAMP_MAX_VALUE)); TransformResult leftSideResult = emptyEvaluator.finishBundle(); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java index 60b9c79..94a0d41 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/TestStreamEvaluatorFactoryTest.java @@ -82,7 +82,7 @@ public class TestStreamEvaluatorFactoryTest { Collection<CommittedBundle<?>> initialInputs = new TestStreamEvaluatorFactory.InputProvider(context) - .getInitialInputs(streamVals.getProducingTransformInternal()); + .getInitialInputs(streamVals.getProducingTransformInternal(), 1); @SuppressWarnings("unchecked") CommittedBundle<TestStreamIndex<Integer>> initialBundle = (CommittedBundle<TestStreamIndex<Integer>>) Iterables.getOnlyElement(initialInputs); 
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java index b78fbeb..76acb03 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/UnboundedReadEvaluatorFactoryTest.java @@ -19,6 +19,7 @@ package org.apache.beam.runners.direct; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.junit.Assert.assertThat; @@ -33,10 +34,12 @@ import com.google.common.collect.Range; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.NoSuchElementException; +import java.util.Set; import javax.annotation.Nullable; import org.apache.beam.runners.direct.DirectRunner.CommittedBundle; import org.apache.beam.runners.direct.DirectRunner.UncommittedBundle; @@ -51,9 +54,12 @@ import org.apache.beam.sdk.io.Read; import org.apache.beam.sdk.io.UnboundedSource; import org.apache.beam.sdk.io.UnboundedSource.CheckpointMark; import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.testing.SourceTestUtils; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.AppliedPTransform; import 
org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.WindowedValue; @@ -66,6 +72,9 @@ import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + /** * Tests for {@link UnboundedReadEvaluatorFactory}. */ @@ -93,12 +102,51 @@ public class UnboundedReadEvaluatorFactoryTest { } @Test + public void generatesInitialSplits() throws Exception { + when(context.createRootBundle()).thenAnswer(new Answer<UncommittedBundle<?>>() { + @Override + public UncommittedBundle<?> answer(InvocationOnMock invocation) throws Throwable { + return bundleFactory.createRootBundle(); + } + }); + + int numSplits = 5; + Collection<CommittedBundle<?>> initialInputs = + new UnboundedReadEvaluatorFactory.InputProvider(context) + .getInitialInputs(longs.getProducingTransformInternal(), numSplits); + // CountingSource.unbounded has very good splitting behavior + assertThat(initialInputs, hasSize(numSplits)); + + int readPerSplit = 100; + int totalSize = numSplits * readPerSplit; + Set<Long> expectedOutputs = + ContiguousSet.create(Range.closedOpen(0L, (long) totalSize), DiscreteDomain.longs()); + + Collection<Long> readItems = new ArrayList<>(totalSize); + for (CommittedBundle<?> initialInput : initialInputs) { + CommittedBundle<UnboundedSourceShard<Long, ?>> shardBundle = + (CommittedBundle<UnboundedSourceShard<Long, ?>>) initialInput; + WindowedValue<UnboundedSourceShard<Long, ?>> shard = + Iterables.getOnlyElement(shardBundle.getElements()); + assertThat(shard.getTimestamp(), equalTo(BoundedWindow.TIMESTAMP_MIN_VALUE)); + assertThat(shard.getWindows(), Matchers.<BoundedWindow>contains(GlobalWindow.INSTANCE)); + UnboundedSource<Long, ?> shardSource = 
shard.getValue().getSource(); + readItems.addAll( + SourceTestUtils.readNItemsFromUnstartedReader( + shardSource.createReader( + PipelineOptionsFactory.create(), null /* No starting checkpoint */), + readPerSplit)); + } + assertThat(readItems, containsInAnyOrder(expectedOutputs.toArray(new Long[0]))); + } + + @Test public void unboundedSourceInMemoryTransformEvaluatorProducesElements() throws Exception { when(context.createRootBundle()).thenReturn(bundleFactory.createRootBundle()); Collection<CommittedBundle<?>> initialInputs = new UnboundedReadEvaluatorFactory.InputProvider(context) - .getInitialInputs(longs.getProducingTransformInternal()); + .getInitialInputs(longs.getProducingTransformInternal(), 1); CommittedBundle<?> inputShards = Iterables.getOnlyElement(initialInputs); UnboundedSourceShard<Long, ?> inputShard = @@ -143,7 +191,8 @@ public class UnboundedReadEvaluatorFactoryTest { when(context.createRootBundle()).thenReturn(bundleFactory.createRootBundle()); Collection<CommittedBundle<?>> initialInputs = - new UnboundedReadEvaluatorFactory.InputProvider(context).getInitialInputs(sourceTransform); + new UnboundedReadEvaluatorFactory.InputProvider(context) + .getInitialInputs(sourceTransform, 1); UncommittedBundle<Long> output = bundleFactory.createBundle(pcollection); when(context.createBundle(pcollection)).thenReturn(output); @@ -198,6 +247,8 @@ public class UnboundedReadEvaluatorFactoryTest { .commit(Instant.now()); UnboundedReadEvaluatorFactory factory = new UnboundedReadEvaluatorFactory(context, 1.0 /* Always reuse */); + new UnboundedReadEvaluatorFactory.InputProvider(context) + .getInitialInputs(pcollection.getProducingTransformInternal(), 1); TransformEvaluator<UnboundedSourceShard<Long, TestCheckpointMark>> evaluator = factory.forApplication(sourceTransform, inputBundle); evaluator.processElement(shard); 
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java ---------------------------------------------------------------------- diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java index 6a9ea6b..20dd2d0 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtils.java @@ -28,8 +28,8 @@ import com.google.api.services.bigquery.model.TableSchema; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.io.BaseEncoding; - import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.List; import javax.annotation.Nullable; import org.apache.avro.Schema; @@ -47,6 +47,19 @@ import org.joda.time.format.DateTimeFormatter; */ class BigQueryAvroUtils { + public static final ImmutableMap<String, Type> BIG_QUERY_TO_AVRO_TYPES = + ImmutableMap.<String, Type>builder() + .put("STRING", Type.STRING) + .put("BYTES", Type.BYTES) + .put("INTEGER", Type.LONG) + .put("FLOAT", Type.DOUBLE) + .put("BOOLEAN", Type.BOOLEAN) + .put("TIMESTAMP", Type.LONG) + .put("RECORD", Type.RECORD) + .put("DATE", Type.STRING) + .put("DATETIME", Type.STRING) + .put("TIME", Type.STRING) + .build(); /** * Formats BigQuery seconds-since-epoch into String matching JSON export. Thread-safe and * immutable. @@ -154,23 +167,10 @@ class BigQueryAvroUtils { // REQUIRED fields are represented as the corresponding Avro types. For example, a BigQuery // INTEGER type maps to an Avro LONG type. 
checkNotNull(v, "REQUIRED field %s should not be null", fieldSchema.getName()); - ImmutableMap<String, Type> fieldMap = - ImmutableMap.<String, Type>builder() - .put("STRING", Type.STRING) - .put("BYTES", Type.BYTES) - .put("INTEGER", Type.LONG) - .put("FLOAT", Type.DOUBLE) - .put("BOOLEAN", Type.BOOLEAN) - .put("TIMESTAMP", Type.LONG) - .put("RECORD", Type.RECORD) - .put("DATE", Type.STRING) - .put("DATETIME", Type.STRING) - .put("TIME", Type.STRING) - .build(); // Per https://cloud.google.com/bigquery/docs/reference/v2/tables#schema, the type field // is required, so it may not be null. String bqType = fieldSchema.getType(); - Type expectedAvroType = fieldMap.get(bqType); + Type expectedAvroType = BIG_QUERY_TO_AVRO_TYPES.get(bqType); verifyNotNull(expectedAvroType, "Unsupported BigQuery type: %s", bqType); verify( avroType == expectedAvroType, @@ -248,4 +248,43 @@ class BigQueryAvroUtils { } return convertRequiredField(unionTypes.get(1).getType(), fieldSchema, v); } + + static Schema toGenericAvroSchema(String schemaName, List<TableFieldSchema> fieldSchemas) { + List<Field> avroFields = new ArrayList<>(); + for (TableFieldSchema bigQueryField : fieldSchemas) { + avroFields.add(convertField(bigQueryField)); + } + return Schema.createRecord( + schemaName, + "org.apache.beam.sdk.io.gcp.bigquery", + "Translated Avro Schema for " + schemaName, + false, + avroFields); + } + + private static Field convertField(TableFieldSchema bigQueryField) { + Type avroType = BIG_QUERY_TO_AVRO_TYPES.get(bigQueryField.getType()); + Schema elementSchema; + if (avroType == Type.RECORD) { + elementSchema = toGenericAvroSchema(bigQueryField.getName(), bigQueryField.getFields()); + } else { + elementSchema = Schema.create(avroType); + } + Schema fieldSchema; + if (bigQueryField.getMode() == null || bigQueryField.getMode().equals("NULLABLE")) { + fieldSchema = Schema.createUnion(Schema.create(Type.NULL), elementSchema); + } else if (bigQueryField.getMode().equals("REQUIRED")) { + fieldSchema 
= elementSchema; + } else if (bigQueryField.getMode().equals("REPEATED")) { + fieldSchema = Schema.createArray(elementSchema); + } else { + throw new IllegalArgumentException( + String.format("Unknown BigQuery Field Mode: %s", bigQueryField.getMode())); + } + return new Field( + bigQueryField.getName(), + fieldSchema, + bigQueryField.getDescription(), + (Object) null /* Cast to avoid deprecated JsonNode constructor. */); + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java index 1d3ea81..644c545 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryAvroUtilsTest.java @@ -17,18 +17,22 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import static org.hamcrest.Matchers.equalTo; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; import com.google.api.services.bigquery.model.TableFieldSchema; import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; +import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.io.BaseEncoding; - import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import org.apache.avro.Schema; +import org.apache.avro.Schema.Field; +import org.apache.avro.Schema.Type; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericRecord; import 
org.apache.avro.reflect.Nullable; @@ -44,36 +48,37 @@ import org.junit.runners.JUnit4; */ @RunWith(JUnit4.class) public class BigQueryAvroUtilsTest { + private List<TableFieldSchema> subFields = Lists.<TableFieldSchema>newArrayList( + new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE")); + /* + * Note that the quality and quantity fields do not have their mode set, so they should default + * to NULLABLE. This is an important test of BigQuery semantics. + * + * All the other fields we set in this function are required on the Schema response. + * + * See https://cloud.google.com/bigquery/docs/reference/v2/tables#schema + */ + private List<TableFieldSchema> fields = + Lists.newArrayList( + new TableFieldSchema().setName("number").setType("INTEGER").setMode("REQUIRED"), + new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE"), + new TableFieldSchema().setName("quality").setType("FLOAT") /* default to NULLABLE */, + new TableFieldSchema().setName("quantity").setType("INTEGER") /* default to NULLABLE */, + new TableFieldSchema().setName("birthday").setType("TIMESTAMP").setMode("NULLABLE"), + new TableFieldSchema().setName("flighted").setType("BOOLEAN").setMode("NULLABLE"), + new TableFieldSchema().setName("sound").setType("BYTES").setMode("NULLABLE"), + new TableFieldSchema().setName("anniversaryDate").setType("DATE").setMode("NULLABLE"), + new TableFieldSchema().setName("anniversaryDatetime") + .setType("DATETIME").setMode("NULLABLE"), + new TableFieldSchema().setName("anniversaryTime").setType("TIME").setMode("NULLABLE"), + new TableFieldSchema().setName("scion").setType("RECORD").setMode("NULLABLE") + .setFields(subFields), + new TableFieldSchema().setName("associates").setType("RECORD").setMode("REPEATED") + .setFields(subFields)); + @Test public void testConvertGenericRecordToTableRow() throws Exception { TableSchema tableSchema = new TableSchema(); - List<TableFieldSchema> subFields = 
Lists.<TableFieldSchema>newArrayList( - new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE")); - /* - * Note that the quality and quantity fields do not have their mode set, so they should default - * to NULLABLE. This is an important test of BigQuery semantics. - * - * All the other fields we set in this function are required on the Schema response. - * - * See https://cloud.google.com/bigquery/docs/reference/v2/tables#schema - */ - List<TableFieldSchema> fields = - Lists.<TableFieldSchema>newArrayList( - new TableFieldSchema().setName("number").setType("INTEGER").setMode("REQUIRED"), - new TableFieldSchema().setName("species").setType("STRING").setMode("NULLABLE"), - new TableFieldSchema().setName("quality").setType("FLOAT") /* default to NULLABLE */, - new TableFieldSchema().setName("quantity").setType("INTEGER") /* default to NULLABLE */, - new TableFieldSchema().setName("birthday").setType("TIMESTAMP").setMode("NULLABLE"), - new TableFieldSchema().setName("flighted").setType("BOOLEAN").setMode("NULLABLE"), - new TableFieldSchema().setName("sound").setType("BYTES").setMode("NULLABLE"), - new TableFieldSchema().setName("anniversaryDate").setType("DATE").setMode("NULLABLE"), - new TableFieldSchema().setName("anniversaryDatetime") - .setType("DATETIME").setMode("NULLABLE"), - new TableFieldSchema().setName("anniversaryTime").setType("TIME").setMode("NULLABLE"), - new TableFieldSchema().setName("scion").setType("RECORD").setMode("NULLABLE") - .setFields(subFields), - new TableFieldSchema().setName("associates").setType("RECORD").setMode("REPEATED") - .setFields(subFields)); tableSchema.setFields(fields); Schema avroSchema = AvroCoder.of(Bird.class).getSchema(); @@ -132,6 +137,77 @@ public class BigQueryAvroUtilsTest { } } + @Test + public void testConvertBigQuerySchemaToAvroSchema() { + TableSchema tableSchema = new TableSchema(); + tableSchema.setFields(fields); + Schema avroSchema = + BigQueryAvroUtils.toGenericAvroSchema("testSchema", 
tableSchema.getFields()); + + assertThat(avroSchema.getField("number").schema(), equalTo(Schema.create(Type.LONG))); + assertThat( + avroSchema.getField("species").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.STRING)))); + assertThat( + avroSchema.getField("quality").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.DOUBLE)))); + assertThat( + avroSchema.getField("quantity").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.LONG)))); + assertThat( + avroSchema.getField("birthday").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.LONG)))); + assertThat( + avroSchema.getField("flighted").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.BOOLEAN)))); + assertThat( + avroSchema.getField("sound").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.BYTES)))); + assertThat( + avroSchema.getField("anniversaryDate").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.STRING)))); + assertThat( + avroSchema.getField("anniversaryDatetime").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.STRING)))); + assertThat( + avroSchema.getField("anniversaryTime").schema(), + equalTo(Schema.createUnion(Schema.create(Type.NULL), Schema.create(Type.STRING)))); + + assertThat( + avroSchema.getField("scion").schema(), + equalTo( + Schema.createUnion( + Schema.create(Type.NULL), + Schema.createRecord( + "scion", + "org.apache.beam.sdk.io.gcp.bigquery", + "Translated Avro Schema for scion", + false, + ImmutableList.of( + new Field( + "species", + Schema.createUnion( + Schema.create(Type.NULL), Schema.create(Type.STRING)), + null, + (Object) null)))))); + assertThat( + avroSchema.getField("associates").schema(), + equalTo( + Schema.createArray( + Schema.createRecord( + "associates", + "org.apache.beam.sdk.io.gcp.bigquery", + 
"Translated Avro Schema for associates", + false, + ImmutableList.of( + new Field( + "species", + Schema.createUnion( + Schema.create(Type.NULL), Schema.create(Type.STRING)), + null, + (Object) null)))))); + } + /** * Pojo class used as the record type in tests. */ http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java index 05a7c5c..9d63611 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTest.java @@ -18,6 +18,8 @@ package org.apache.beam.sdk.io.gcp.bigquery; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.fromJsonString; import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.toJsonString; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; @@ -32,7 +34,9 @@ import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.when; +import com.google.api.client.json.GenericJson; import com.google.api.client.util.Data; +import com.google.api.services.bigquery.model.Dataset; import com.google.api.services.bigquery.model.ErrorProto; import com.google.api.services.bigquery.model.Job; import com.google.api.services.bigquery.model.JobConfigurationExtract; @@ -50,21 +54,36 @@ import 
com.google.api.services.bigquery.model.TableReference; import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; import com.google.common.base.Strings; +import com.google.common.collect.HashBasedTable; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; +import com.google.common.collect.Table.Cell; +import java.io.ByteArrayInputStream; import java.io.File; import java.io.FileFilter; import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.io.Serializable; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; import java.nio.file.Files; import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; import java.util.Set; import javax.annotation.Nullable; +import org.apache.avro.Schema; +import org.apache.avro.file.DataFileWriter; +import org.apache.avro.generic.GenericDatumWriter; +import org.apache.avro.generic.GenericRecord; +import org.apache.avro.generic.GenericRecordBuilder; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.KvCoder; @@ -110,7 +129,9 @@ import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.IOChannelFactory; import org.apache.beam.sdk.util.IOChannelUtils; +import org.apache.beam.sdk.util.MimeTypes; import org.apache.beam.sdk.util.PCollectionViews; +import org.apache.beam.sdk.util.Transport; import org.apache.beam.sdk.util.WindowingStrategy; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; @@ -144,6 +165,7 @@ public class BigQueryIOTest implements Serializable { Status.SUCCEEDED, 
new Job().setStatus(new JobStatus()), Status.FAILED, new Job().setStatus(new JobStatus().setErrorResult(new ErrorProto()))); + private static class FakeBigQueryServices implements BigQueryServices { private String[] jsonTableRowReturns = new String[0]; @@ -290,25 +312,25 @@ public class BigQueryIOTest implements Serializable { @Override public void startLoadJob(JobReference jobRef, JobConfigurationLoad loadConfig) throws InterruptedException, IOException { - startJob(jobRef); + startJob(jobRef, loadConfig); } @Override public void startExtractJob(JobReference jobRef, JobConfigurationExtract extractConfig) throws InterruptedException, IOException { - startJob(jobRef); + startJob(jobRef, extractConfig); } @Override public void startQueryJob(JobReference jobRef, JobConfigurationQuery query) throws IOException, InterruptedException { - startJob(jobRef); + startJob(jobRef, query); } @Override public void startCopyJob(JobReference jobRef, JobConfigurationTableCopy copyConfig) throws IOException, InterruptedException { - startJob(jobRef); + startJob(jobRef, copyConfig); } @Override @@ -338,7 +360,8 @@ public class BigQueryIOTest implements Serializable { } } - private void startJob(JobReference jobRef) throws IOException, InterruptedException { + private void startJob(JobReference jobRef, GenericJson config) + throws IOException, InterruptedException { if (!Strings.isNullOrEmpty(executingProject)) { checkArgument( jobRef.getProjectId().equals(executingProject), @@ -352,6 +375,11 @@ public class BigQueryIOTest implements Serializable { throw (IOException) ret; } else if (ret instanceof InterruptedException) { throw (InterruptedException) ret; + } else if (ret instanceof SerializableFunction) { + SerializableFunction<GenericJson, Void> fn = + (SerializableFunction<GenericJson, Void>) ret; + fn.apply(config); + return; } else { return; } @@ -392,6 +420,178 @@ public class BigQueryIOTest implements Serializable { "Exceeded expected number of calls: " + getJobReturns.length); 
} } + + ////////////////////////////////// SERIALIZATION METHODS //////////////////////////////////// + private void writeObject(ObjectOutputStream out) throws IOException { + out.writeObject(replaceJobsWithBytes(startJobReturns)); + out.writeObject(replaceJobsWithBytes(pollJobReturns)); + out.writeObject(replaceJobsWithBytes(getJobReturns)); + out.writeObject(executingProject); + } + + private Object[] replaceJobsWithBytes(Object[] objs) { + Object[] copy = Arrays.copyOf(objs, objs.length); + for (int i = 0; i < copy.length; i++) { + checkArgument( + copy[i] == null || copy[i] instanceof Serializable || copy[i] instanceof Job, + "Only serializable elements and jobs can be added add to Job Returns"); + if (copy[i] instanceof Job) { + try { + // Job is not serializable, so encode the job as a byte array. + copy[i] = Transport.getJsonFactory().toByteArray(copy[i]); + } catch (IOException e) { + throw new IllegalArgumentException( + String.format("Could not encode Job %s via available JSON factory", copy[i])); + } + } + } + return copy; + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + this.startJobReturns = replaceBytesWithJobs(in.readObject()); + this.pollJobReturns = replaceBytesWithJobs(in.readObject()); + this.getJobReturns = replaceBytesWithJobs(in.readObject()); + this.executingProject = (String) in.readObject(); + } + + private Object[] replaceBytesWithJobs(Object obj) throws IOException { + checkState(obj instanceof Object[]); + Object[] objs = (Object[]) obj; + Object[] copy = Arrays.copyOf(objs, objs.length); + for (int i = 0; i < copy.length; i++) { + if (copy[i] instanceof byte[]) { + Job job = Transport.getJsonFactory() + .createJsonParser(new ByteArrayInputStream((byte[]) copy[i])) + .parse(Job.class); + copy[i] = job; + } + } + return copy; + } + } + + /** A fake dataset service that can be serialized, for use in testReadFromTable. 
*/ + private static class FakeDatasetService implements DatasetService, Serializable { + private com.google.common.collect.Table<String, String, Map<String, Table>> tables = + HashBasedTable.create(); + + public FakeDatasetService withTable( + String projectId, String datasetId, String tableId, Table table) throws IOException { + Map<String, Table> dataset = tables.get(projectId, datasetId); + if (dataset == null) { + dataset = new HashMap<>(); + tables.put(projectId, datasetId, dataset); + } + dataset.put(tableId, table); + return this; + } + + @Override + public Table getTable(String projectId, String datasetId, String tableId) + throws InterruptedException, IOException { + Map<String, Table> dataset = + checkNotNull( + tables.get(projectId, datasetId), + "Tried to get a table %s:%s.%s from %s, but no such table was set", + projectId, + datasetId, + tableId, + FakeDatasetService.class.getSimpleName()); + return checkNotNull(dataset.get(tableId), + "Tried to get a table %s:%s.%s from %s, but no such table was set", + projectId, + datasetId, + tableId, + FakeDatasetService.class.getSimpleName()); + } + + @Override + public void deleteTable(String projectId, String datasetId, String tableId) + throws IOException, InterruptedException { + throw new UnsupportedOperationException("Unsupported"); + } + + @Override + public boolean isTableEmpty(String projectId, String datasetId, String tableId) + throws IOException, InterruptedException { + Long numBytes = getTable(projectId, datasetId, tableId).getNumBytes(); + return numBytes == null || numBytes == 0L; + } + + @Override + public Dataset getDataset( + String projectId, String datasetId) throws IOException, InterruptedException { + throw new UnsupportedOperationException("Unsupported"); + } + + @Override + public void createDataset( + String projectId, String datasetId, String location, String description) + throws IOException, InterruptedException { + throw new UnsupportedOperationException("Unsupported"); + } + + 
@Override + public void deleteDataset(String projectId, String datasetId) + throws IOException, InterruptedException { + throw new UnsupportedOperationException("Unsupported"); + } + + @Override + public long insertAll( + TableReference ref, List<TableRow> rowList, @Nullable List<String> insertIdList) + throws IOException, InterruptedException { + throw new UnsupportedOperationException("Unsupported"); + } + + ////////////////////////////////// SERIALIZATION METHODS //////////////////////////////////// + private void writeObject(ObjectOutputStream out) throws IOException { + out.writeObject(replaceTablesWithBytes(this.tables)); + } + + private com.google.common.collect.Table<String, String, Map<String, byte[]>> + replaceTablesWithBytes( + com.google.common.collect.Table<String, String, Map<String, Table>> toCopy) + throws IOException { + com.google.common.collect.Table<String, String, Map<String, byte[]>> copy = + HashBasedTable.create(); + for (Cell<String, String, Map<String, Table>> cell : toCopy.cellSet()) { + HashMap<String, byte[]> dataset = new HashMap<>(); + copy.put(cell.getRowKey(), cell.getColumnKey(), dataset); + for (Map.Entry<String, Table> dsTables : cell.getValue().entrySet()) { + dataset.put( + dsTables.getKey(), Transport.getJsonFactory().toByteArray(dsTables.getValue())); + } + } + return copy; + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + com.google.common.collect.Table<String, String, Map<String, byte[]>> tablesTable = + (com.google.common.collect.Table<String, String, Map<String, byte[]>>) in.readObject(); + this.tables = replaceBytesWithTables(tablesTable); + } + + private com.google.common.collect.Table<String, String, Map<String, Table>> + replaceBytesWithTables( + com.google.common.collect.Table<String, String, Map<String, byte[]>> tablesTable) + throws IOException { + com.google.common.collect.Table<String, String, Map<String, Table>> copy = + HashBasedTable.create(); + for 
(Cell<String, String, Map<String, byte[]>> cell : tablesTable.cellSet()) { + HashMap<String, Table> dataset = new HashMap<>(); + copy.put(cell.getRowKey(), cell.getColumnKey(), dataset); + for (Map.Entry<String, byte[]> dsTables : cell.getValue().entrySet()) { + Table table = + Transport.getJsonFactory() + .createJsonParser(new ByteArrayInputStream(dsTables.getValue())) + .parse(Table.class); + dataset.put(dsTables.getKey(), table); + } + } + return copy; + } } @Rule public transient ExpectedException thrown = ExpectedException.none(); @@ -627,11 +827,54 @@ public class BigQueryIOTest implements Serializable { bqOptions.setProject("defaultProject"); bqOptions.setTempLocation(testFolder.newFolder("BigQueryIOTest").getAbsolutePath()); + Job job = new Job(); + JobStatus status = new JobStatus(); + job.setStatus(status); + JobStatistics jobStats = new JobStatistics(); + job.setStatistics(jobStats); + JobStatistics4 extract = new JobStatistics4(); + jobStats.setExtract(extract); + extract.setDestinationUriFileCounts(ImmutableList.of(1L)); + + Table sometable = new Table(); + sometable.setSchema( + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("name").setType("STRING"), + new TableFieldSchema().setName("number").setType("INTEGER")))); + sometable.setNumBytes(1024L * 1024L); + FakeDatasetService fakeDatasetService = + new FakeDatasetService() + .withTable("non-executing-project", "somedataset", "sometable", sometable); + SerializableFunction<Void, Schema> schemaGenerator = + new SerializableFunction<Void, Schema>() { + @Override + public Schema apply(Void input) { + return BigQueryAvroUtils.toGenericAvroSchema( + "sometable", + ImmutableList.of( + new TableFieldSchema().setName("name").setType("STRING"), + new TableFieldSchema().setName("number").setType("INTEGER"))); + } + }; + Collection<Map<String, Object>> records = + ImmutableList.<Map<String, Object>>builder() + .add(ImmutableMap.<String, Object>builder().put("name", 
"a").put("number", 1L).build()) + .add(ImmutableMap.<String, Object>builder().put("name", "b").put("number", 2L).build()) + .add(ImmutableMap.<String, Object>builder().put("name", "c").put("number", 3L).build()) + .build(); + + SerializableFunction<GenericJson, Void> onStartJob = + new WriteExtractFiles(schemaGenerator, records); + FakeBigQueryServices fakeBqServices = new FakeBigQueryServices() .withJobService(new FakeJobService() - .startJobReturns("done", "done") + .startJobReturns(onStartJob, "done") + .pollJobReturns(job) .getJobReturns((Job) null) .verifyExecutingProject(bqOptions.getProject())) + .withDatasetService(fakeDatasetService) .readerReturns( toJsonString(new TableRow().set("name", "a").set("number", 1)), toJsonString(new TableRow().set("name", "b").set("number", 2)), @@ -1701,4 +1944,41 @@ public class BigQueryIOTest implements Serializable { return pathname.isFile(); }}).length); } + + private class WriteExtractFiles implements SerializableFunction<GenericJson, Void> { + private final SerializableFunction<Void, Schema> schemaGenerator; + private final Collection<Map<String, Object>> records; + + private WriteExtractFiles( + SerializableFunction<Void, Schema> schemaGenerator, + Collection<Map<String, Object>> records) { + this.schemaGenerator = schemaGenerator; + this.records = records; + } + + @Override + public Void apply(GenericJson input) { + List<String> destinations = (List<String>) input.get("destinationUris"); + for (String destination : destinations) { + String newDest = destination.replace("*", "000000000000"); + Schema schema = schemaGenerator.apply(null); + try (WritableByteChannel channel = IOChannelUtils.create(newDest, MimeTypes.BINARY); + DataFileWriter<GenericRecord> tableRowWriter = + new DataFileWriter<>(new GenericDatumWriter<GenericRecord>(schema)) + .create(schema, Channels.newOutputStream(channel))) { + for (Map<String, Object> record : records) { + GenericRecordBuilder genericRecordBuilder = new GenericRecordBuilder(schema); 
+ for (Map.Entry<String, Object> field : record.entrySet()) { + genericRecordBuilder.set(field.getKey(), field.getValue()); + } + tableRowWriter.append(genericRecordBuilder.build()); + } + } catch (IOException e) { + throw new IllegalStateException( + String.format("Could not create destination for extract job %s", destination), e); + } + } + return null; + } + } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/f68fea02/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java index f21e6c0..3ca2b64 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigtable/BigtableIOTest.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.io.gcp.bigtable; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Verify.verifyNotNull; import static org.apache.beam.sdk.testing.SourceTestUtils.assertSourcesEqualReferenceSource; import static org.apache.beam.sdk.testing.SourceTestUtils.assertSplitAtFractionExhaustive; @@ -222,6 +223,7 @@ public class BigtableIOTest { public void testReadingEmptyTable() throws Exception { final String table = "TEST-EMPTY-TABLE"; service.createTable(table); + service.setupSampleRowKeys(table, 1, 1L); runReadTest(defaultRead.withTableId(table), new ArrayList<Row>()); logged.verifyInfo("Closing reader after reading 0 records."); @@ -234,8 +236,9 @@ public class BigtableIOTest { final int numRows = 1001; List<Row> testRows = makeTableData(table, 
numRows); + service.setupSampleRowKeys(table, 3, 1000L); runReadTest(defaultRead.withTableId(table), testRows); - logged.verifyInfo(String.format("Closing reader after reading %d records.", numRows)); + logged.verifyInfo(String.format("Closing reader after reading %d records.", numRows / 3)); } /** A {@link Predicate} that a {@link Row Row's} key matches the given regex. */ @@ -284,6 +287,7 @@ public class BigtableIOTest { ByteKey startKey = ByteKey.copyFrom("key000000100".getBytes()); ByteKey endKey = ByteKey.copyFrom("key000000300".getBytes()); + service.setupSampleRowKeys(table, numRows / 10, "key000000100".length()); // Test prefix: [beginning, startKey). final ByteKeyRange prefixRange = ByteKeyRange.ALL_KEYS.withEndKey(startKey); List<Row> prefixRows = filterToRange(testRows, prefixRange); @@ -336,6 +340,7 @@ public class BigtableIOTest { RowFilter filter = RowFilter.newBuilder().setRowKeyRegexFilter(ByteString.copyFromUtf8(regex)).build(); + service.setupSampleRowKeys(table, 5, 10L); runReadTest( defaultRead.withTableId(table).withRowFilter(filter), Lists.newArrayList(filteredRows)); @@ -743,7 +748,7 @@ public class BigtableIOTest { @Override public List<SampleRowKeysResponse> getSampleRowKeys(BigtableSource source) { List<SampleRowKeysResponse> samples = sampleRowKeys.get(source.getTableId()); - checkArgument(samples != null, "No samples found for table %s", source.getTableId()); + checkNotNull(samples, "No samples found for table %s", source.getTableId()); return samples; }