[ 
https://issues.apache.org/jira/browse/BEAM-4061?focusedWorklogId=99741&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-99741
 ]

ASF GitHub Bot logged work on BEAM-4061:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 08/May/18 21:59
            Start Date: 08/May/18 21:59
    Worklog Time Spent: 10m 
      Work Description: jkff closed pull request #4264: [BEAM-4061] Introduced SpannerWriteResult
URL: https://github.com/apache/beam/pull/4264
 
 
   

This PR was merged from a forked repository.
Because GitHub does not display the original diff of a merged pull request that comes from a fork,
the diff is reproduced below for the sake of provenance:

diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
index b950964411c..c3924d861f3 100644
--- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java
@@ -23,12 +23,12 @@
 import com.google.auto.value.AutoValue;
 import com.google.cloud.ServiceFactory;
 import com.google.cloud.Timestamp;
-import com.google.cloud.spanner.AbortedException;
 import com.google.cloud.spanner.Key;
 import com.google.cloud.spanner.KeySet;
 import com.google.cloud.spanner.Mutation;
 import com.google.cloud.spanner.PartitionOptions;
 import com.google.cloud.spanner.Spanner;
+import com.google.cloud.spanner.SpannerException;
 import com.google.cloud.spanner.SpannerOptions;
 import com.google.cloud.spanner.Statement;
 import com.google.cloud.spanner.Struct;
@@ -47,6 +47,7 @@
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.options.ValueProvider;
 import org.apache.beam.sdk.transforms.ApproximateQuantiles;
@@ -60,16 +61,15 @@
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.Wait;
 import org.apache.beam.sdk.transforms.display.DisplayData;
-import org.apache.beam.sdk.util.BackOff;
-import org.apache.beam.sdk.util.BackOffUtils;
-import org.apache.beam.sdk.util.FluentBackoff;
-import org.apache.beam.sdk.util.Sleeper;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
 import org.apache.beam.sdk.values.PCollectionView;
-import org.apache.beam.sdk.values.PDone;
-import org.joda.time.Duration;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * Experimental {@link PTransform Transforms} for reading from and writing to 
<a
@@ -179,6 +179,7 @@
  */
 @Experimental(Experimental.Kind.SOURCE_SINK)
 public class SpannerIO {
+  private static final Logger LOG = LoggerFactory.getLogger(SpannerIO.class);
 
   private static final long DEFAULT_BATCH_SIZE_BYTES = 1024L * 1024L; // 1 MB
   // Max number of mutations to batch together.
@@ -240,6 +241,7 @@ public static Write write() {
         .setSpannerConfig(SpannerConfig.create())
         .setBatchSizeBytes(DEFAULT_BATCH_SIZE_BYTES)
         .setNumSamples(DEFAULT_NUM_SAMPLES)
+        .setFailureMode(FailureMode.FAIL_FAST)
         .build();
   }
 
@@ -643,6 +645,15 @@ public CreateTransaction withTimestampBound(TimestampBound 
timestampBound) {
     }
   }
 
+  /**
+   * A failure handling strategy.
+   */
+  public enum FailureMode {
+    /** Invalid write to Spanner will cause the pipeline to fail. A default 
strategy. */
+    FAIL_FAST,
+    /** Invalid mutations will be returned as part of the result of the write 
transform. */
+    REPORT_FAILURES
+  }
 
   /**
    * A {@link PTransform} that writes {@link Mutation} objects to Google Cloud 
Spanner.
@@ -651,7 +662,7 @@ public CreateTransaction withTimestampBound(TimestampBound 
timestampBound) {
    */
   @Experimental(Experimental.Kind.SOURCE_SINK)
   @AutoValue
-  public abstract static class Write extends PTransform<PCollection<Mutation>, 
PDone> {
+  public abstract static class Write extends PTransform<PCollection<Mutation>, 
SpannerWriteResult> {
 
     abstract SpannerConfig getSpannerConfig();
 
@@ -659,6 +670,8 @@ public CreateTransaction withTimestampBound(TimestampBound 
timestampBound) {
 
     abstract int getNumSamples();
 
+    abstract FailureMode getFailureMode();
+
     @Nullable
      abstract PTransform<PCollection<KV<String, byte[]>>, 
PCollection<KV<String, List<byte[]>>>>
          getSampler();
@@ -674,6 +687,8 @@ public CreateTransaction withTimestampBound(TimestampBound 
timestampBound) {
 
       abstract Builder setNumSamples(int numSamples);
 
+      abstract Builder setFailureMode(FailureMode failureMode);
+
       abstract Builder setSampler(
           PTransform<PCollection<KV<String, byte[]>>, PCollection<KV<String, 
List<byte[]>>>>
               sampler);
@@ -755,14 +770,18 @@ public Write withBatchSizeBytes(long batchSizeBytes) {
       return toBuilder().setBatchSizeBytes(batchSizeBytes).build();
     }
 
+    /** Specifies failure mode. {@link FailureMode#FAIL_FAST} mode is selected 
by default. */
+    public Write withFailureMode(FailureMode failureMode) {
+      return toBuilder().setFailureMode(failureMode).build();
+    }
+
     @Override
-    public PDone expand(PCollection<Mutation> input) {
+    public SpannerWriteResult expand(PCollection<Mutation> input) {
       getSpannerConfig().validate();
 
-      input
+      return input
           .apply("To mutation group", ParDo.of(new ToMutationGroupFn()))
           .apply("Write mutations to Cloud Spanner", new WriteGrouped(this));
-      return PDone.in(input.getPipeline());
     }
 
     @Override
@@ -788,7 +807,8 @@ public void populateDisplayData(DisplayData.Builder 
builder) {
   }
 
   /** Same as {@link Write} but supports grouped mutations. */
-  public static class WriteGrouped extends 
PTransform<PCollection<MutationGroup>, PDone> {
+  public static class WriteGrouped
+      extends PTransform<PCollection<MutationGroup>, SpannerWriteResult> {
     private final Write spec;
 
     public WriteGrouped(Write spec) {
@@ -796,7 +816,8 @@ public WriteGrouped(Write spec) {
     }
 
     @Override
-    public PDone expand(PCollection<MutationGroup> input) {
+    public SpannerWriteResult expand(PCollection<MutationGroup> input) {
+
       PTransform<PCollection<KV<String, byte[]>>, PCollection<KV<String, 
List<byte[]>>>>
           sampler = spec.getSampler();
       if (sampler == null) {
@@ -827,21 +848,23 @@ public PDone expand(PCollection<MutationGroup> input) {
               .apply("Sample keys", sampler)
               .apply("Keys sample as view", View.asMap());
 
+      TupleTag<Void> mainTag = new TupleTag<>("mainOut");
+      TupleTag<MutationGroup> failedTag = new TupleTag<>("failedMutations");
       // Assign partition based on the closest element in the sample and group 
mutations.
       AssignPartitionFn assignPartitionFn = new AssignPartitionFn(keySample);
-      serialized
+      PCollectionTuple result = serialized
           .apply("Partition input", 
ParDo.of(assignPartitionFn).withSideInputs(keySample))
           .setCoder(KvCoder.of(StringUtf8Coder.of(), 
SerializedMutationCoder.of()))
-          .apply("Group by partition", GroupByKey.create())
-          .apply(
-              "Batch mutations together",
+          .apply("Group by partition", GroupByKey.create()).apply("Batch 
mutations together",
               ParDo.of(new BatchFn(spec.getBatchSizeBytes(), 
spec.getSpannerConfig(), schemaView))
-                  .withSideInputs(schemaView))
-          .apply(
-              "Write mutations to Spanner",
-              ParDo.of(new WriteToSpannerFn(spec.getSpannerConfig())));
-      return PDone.in(input.getPipeline());
-
+                  .withSideInputs(schemaView)).apply("Write mutations to 
Spanner",
+              ParDo.of(new WriteToSpannerFn(spec.getSpannerConfig(), 
spec.getFailureMode(),
+                  failedTag))
+                  .withOutputTags(mainTag, TupleTagList.of(failedTag)));
+      PCollection<MutationGroup> failedMutations = result.get(failedTag);
+      failedMutations.setCoder(SerializableCoder.of(MutationGroup.class));
+      return new SpannerWriteResult(input.getPipeline(), result.get(mainTag), 
failedMutations,
+          failedTag);
     }
 
     private PTransform<PCollection<KV<String, byte[]>>, PCollection<KV<String, 
List<byte[]>>>>
@@ -948,11 +971,7 @@ public AssignPartitionFn(PCollectionView<Map<String, 
List<byte[]>>> sampleView)
    * Batches mutations together.
    */
   private static class BatchFn
-      extends DoFn<KV<String, Iterable<SerializedMutation>>, 
Iterable<Mutation>> {
-
-    private static final int MAX_RETRIES = 5;
-    private static final FluentBackoff BUNDLE_WRITE_BACKOFF = 
FluentBackoff.DEFAULT
-        
.withMaxRetries(MAX_RETRIES).withInitialBackoff(Duration.standardSeconds(5));
+      extends DoFn<KV<String, Iterable<SerializedMutation>>, 
Iterable<MutationGroup>> {
 
     private final long maxBatchSizeBytes;
     private final SpannerConfig spannerConfig;
@@ -960,7 +979,7 @@ public AssignPartitionFn(PCollectionView<Map<String, 
List<byte[]>>> sampleView)
 
     private transient SpannerAccessor spannerAccessor;
     // Current batch of mutations to be written.
-    private List<Mutation> mutations;
+    private List<MutationGroup> mutations;
     // total size of the current batch.
     private long batchSizeBytes;
 
@@ -972,14 +991,14 @@ private BatchFn(long maxBatchSizeBytes, SpannerConfig 
spannerConfig,
     }
 
     @Setup
-    public void setup() throws Exception {
+    public void setup() {
       mutations = new ArrayList<>();
       batchSizeBytes = 0;
       spannerAccessor = spannerConfig.connectToSpanner();
     }
 
     @Teardown
-    public void teardown() throws Exception {
+    public void teardown() {
       spannerAccessor.close();
     }
 
@@ -990,7 +1009,7 @@ public void processElement(ProcessContext c) throws 
Exception {
       for (SerializedMutation kv : element.getValue()) {
         byte[] value = kv.getMutationGroupBytes();
         MutationGroup mg = mutationGroupEncoder.decode(value);
-        Iterables.addAll(mutations, mg);
+        mutations.add(mg);
         batchSizeBytes += MutationSizeEstimator.sizeOf(mg);
         if (batchSizeBytes >= maxBatchSizeBytes || mutations.size() > 
MAX_NUM_MUTATIONS) {
           c.output(mutations);
@@ -1007,16 +1026,19 @@ public void processElement(ProcessContext c) throws 
Exception {
   }
 
   private static class WriteToSpannerFn
-      extends DoFn<Iterable<Mutation>, Void> {
-    private static final int MAX_RETRIES = 5;
-    private static final FluentBackoff BUNDLE_WRITE_BACKOFF = 
FluentBackoff.DEFAULT
-        
.withMaxRetries(MAX_RETRIES).withInitialBackoff(Duration.standardSeconds(5));
+      extends DoFn<Iterable<MutationGroup>, Void> {
 
     private transient SpannerAccessor spannerAccessor;
     private final SpannerConfig spannerConfig;
+    private final FailureMode failureMode;
+
+    private final TupleTag<MutationGroup> failedTag;
 
-    public WriteToSpannerFn(SpannerConfig spannerConfig) {
+    WriteToSpannerFn(SpannerConfig spannerConfig, FailureMode failureMode,
+        TupleTag<MutationGroup> failedTag) {
       this.spannerConfig = spannerConfig;
+      this.failureMode = failureMode;
+      this.failedTag = failedTag;
     }
 
     @Setup
@@ -1032,22 +1054,28 @@ public void teardown() throws Exception {
 
     @ProcessElement
     public void processElement(ProcessContext c) throws Exception {
-      Sleeper sleeper = Sleeper.DEFAULT;
-      BackOff backoff = BUNDLE_WRITE_BACKOFF.backoff();
-
-      Iterable<Mutation> mutations = c.element();
-
-      while (true) {
-        // Batch upsert rows.
-        try {
-          spannerAccessor.getDatabaseClient().writeAtLeastOnce(mutations);
-          // Break if the commit threw no exception.
-          break;
-        } catch (AbortedException exception) {
-          // Only log the code and message for potentially-transient errors. 
The entire exception
-          // will be propagated upon the last retry.
-          if (!BackOffUtils.next(sleeper, backoff)) {
-            throw exception;
+      Iterable<MutationGroup> mutations = c.element();
+      boolean tryIndividual = false;
+      // Batch upsert rows.
+      try {
+        Iterable<Mutation> batch = Iterables.concat(mutations);
+        spannerAccessor.getDatabaseClient().writeAtLeastOnce(batch);
+      } catch (SpannerException e) {
+        if (failureMode == FailureMode.REPORT_FAILURES) {
+          tryIndividual = true;
+        } else if (failureMode == FailureMode.FAIL_FAST) {
+          throw e;
+        } else {
+          throw new IllegalArgumentException("Unknown failure mode " + 
failureMode);
+        }
+      }
+      if (tryIndividual) {
+        for (MutationGroup mg : mutations) {
+          try {
+            spannerAccessor.getDatabaseClient().writeAtLeastOnce(mg);
+          } catch (SpannerException e) {
+            LOG.warn("Failed to submit the mutation group", e);
+            c.output(failedTag, mg);
           }
         }
       }
diff --git 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteResult.java
 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteResult.java
new file mode 100644
index 00000000000..416ab2a0c4d
--- /dev/null
+++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteResult.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.gcp.spanner;
+
+import com.google.common.collect.ImmutableMap;
+import java.util.Map;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
+import org.apache.beam.sdk.values.PValue;
+import org.apache.beam.sdk.values.TupleTag;
+
+/**
+ * A result of {@link SpannerIO#write()} transform. Use {@link 
#getFailedMutations} to access
+ * failed Mutations. {@link #getOutput()} can be used as a completion signal 
with the
+ * {@link org.apache.beam.sdk.transforms.Wait} transform.
+ */
+public class SpannerWriteResult implements POutput {
+  private final Pipeline pipeline;
+  private final PCollection<Void> output;
+  private final PCollection<MutationGroup> failedMutations;
+  private final TupleTag<MutationGroup> failedMutationsTag;
+
+  public SpannerWriteResult(Pipeline pipeline, PCollection<Void> output,
+      PCollection<MutationGroup> failedMutations, TupleTag<MutationGroup> 
failedMutationsTag) {
+    this.pipeline = pipeline;
+    this.output = output;
+    this.failedMutations = failedMutations;
+    this.failedMutationsTag = failedMutationsTag;
+  }
+
+  @Override
+  public Pipeline getPipeline() {
+    return pipeline;
+  }
+
+  @Override
+  public Map<TupleTag<?>, PValue> expand() {
+    return ImmutableMap.of(failedMutationsTag, failedMutations);
+  }
+
+  public PCollection<MutationGroup> getFailedMutations() {
+    return failedMutations;
+  }
+
+  public PCollection<Void> getOutput() {
+    return output;
+  }
+
+  @Override
+  public void finishSpecifyingOutput(String transformName, PInput input,
+      PTransform<?, ?> transform) {
+
+  }
+}
diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
index 61afb60b35a..a67adc67c1e 100644
--- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
+++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIOWriteTest.java
@@ -19,19 +19,23 @@
 
 import static 
org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
 import static org.hamcrest.Matchers.hasSize;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
+import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.argThat;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import com.google.cloud.spanner.ErrorCode;
 import com.google.cloud.spanner.Key;
 import com.google.cloud.spanner.KeyRange;
 import com.google.cloud.spanner.KeySet;
 import com.google.cloud.spanner.Mutation;
 import com.google.cloud.spanner.ReadOnlyTransaction;
 import com.google.cloud.spanner.ResultSets;
+import com.google.cloud.spanner.SpannerExceptionFactory;
 import com.google.cloud.spanner.Statement;
 import com.google.cloud.spanner.Struct;
 import com.google.cloud.spanner.Type;
@@ -45,6 +49,7 @@
 import java.util.List;
 import java.util.Map;
 import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -350,6 +355,45 @@ public void batchingPlusSampling() throws Exception {
     );
   }
 
+  @Test
+  @Category(NeedsRunner.class)
+  public void reportFailures() throws Exception {
+    PCollection<MutationGroup> mutations = pipeline
+        .apply(Create.of(
+            g(m(1L)), g(m(2L)), g(m(3L)), g(m(4L)),  g(m(5L)),
+            g(m(6L)), g(m(7L)), g(m(8L)), g(m(9L)),  g(m(10L)))
+        );
+
+    when(serviceFactory.mockDatabaseClient().writeAtLeastOnce(any()))
+        .thenAnswer(invocationOnMock -> {
+          throw 
SpannerExceptionFactory.newSpannerException(ErrorCode.ALREADY_EXISTS, "oops");
+        });
+
+    SpannerWriteResult result = mutations.apply(SpannerIO.write()
+        .withProjectId("test-project")
+        .withInstanceId("test-instance")
+        .withDatabaseId("test-database")
+        .withServiceFactory(serviceFactory)
+        .withBatchSizeBytes(1000000000)
+        .withFailureMode(SpannerIO.FailureMode.REPORT_FAILURES)
+        .withSampler(fakeSampler(m(2L), m(5L), m(10L)))
+        .grouped());
+    PAssert.that(result.getFailedMutations()).satisfies(m -> {
+      assertEquals(10, Iterables.size(m));
+      return null;
+    });
+    pipeline.run();
+
+    verifyBatches(
+        batch(m(1L), m(2L)),
+        batch(m(3L), m(4L), m(5L)),
+        batch(m(6L), m(7L), m(8L), m(9L), m(10L)),
+        // Mutations were also retried individually.
+        batch(m(1L)), batch(m(2L)), batch(m(3L)), batch(m(4L)),
+        batch(m(5L)), batch(m(6L)), batch(m(7L)), batch(m(8L)),
+        batch(m(9L)), batch(m(10L)));
+  }
+
   @Test
   @Category(NeedsRunner.class)
   public void noBatchingPlusSampling() throws Exception {
diff --git 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java
 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java
index 89be159db5c..7b813fca706 100644
--- 
a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java
+++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerWriteIT.java
@@ -23,7 +23,6 @@
 
 import com.google.cloud.spanner.Database;
 import com.google.cloud.spanner.DatabaseAdminClient;
-import com.google.cloud.spanner.DatabaseClient;
 import com.google.cloud.spanner.DatabaseId;
 import com.google.cloud.spanner.Mutation;
 import com.google.cloud.spanner.Operation;
@@ -31,8 +30,14 @@
 import com.google.cloud.spanner.Spanner;
 import com.google.cloud.spanner.SpannerOptions;
 import com.google.cloud.spanner.Statement;
+import com.google.common.base.Predicate;
+import com.google.common.base.Predicates;
 import com.google.spanner.admin.database.v1.CreateDatabaseMetadata;
+import java.io.Serializable;
 import java.util.Collections;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
 import org.apache.beam.sdk.io.GenerateSequence;
 import org.apache.beam.sdk.options.Default;
@@ -42,10 +47,13 @@
 import org.apache.beam.sdk.testing.TestPipelineOptions;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Wait;
+import org.hamcrest.Matchers;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.ExpectedException;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
@@ -56,11 +64,12 @@
   private static final int MAX_DB_NAME_LENGTH = 30;
 
   @Rule public final transient TestPipeline p = TestPipeline.create();
+  @Rule public transient ExpectedException thrown = ExpectedException.none();
 
   /** Pipeline options for this test. */
   public interface SpannerTestPipelineOptions extends TestPipelineOptions {
     @Description("Instance ID to write to in Spanner")
-    @Default.String("beam-test")
+    @Default.String("mairbek-deleteme")
     String getInstanceId();
     void setInstanceId(String value);
 
@@ -106,7 +115,7 @@ public void setUp() throws Exception {
                     + options.getTable()
                     + " ("
                     + "  Key           INT64,"
-                    + "  Value         STRING(MAX),"
+                    + "  Value         STRING(MAX) NOT NULL,"
                     + ") PRIMARY KEY (Key)"));
     op.waitFor();
   }
@@ -128,19 +137,72 @@ public void testWrite() throws Exception {
                 .withInstanceId(options.getInstanceId())
                 .withDatabaseId(databaseName));
 
-    p.run();
-    DatabaseClient databaseClient =
-        spanner.getDatabaseClient(
-            DatabaseId.of(
-                project, options.getInstanceId(), databaseName));
+    PipelineResult result = p.run();
+    result.waitUntilFinish();
+    assertThat(result.getState(), is(PipelineResult.State.DONE));
+    assertThat(countNumberOfRecords(), equalTo((long) numRecords));
+  }
 
-    ResultSet resultSet =
-        databaseClient
-            .singleUse()
-            .executeQuery(Statement.of("SELECT COUNT(*) FROM " + 
options.getTable()));
-    assertThat(resultSet.next(), is(true));
-    assertThat(resultSet.getLong(0), equalTo((long) numRecords));
-    assertThat(resultSet.next(), is(false));
+  @Test
+  public void testSequentialWrite() throws Exception {
+    int numRecords = 100;
+
+    SpannerWriteResult stepOne = p.apply("first step", 
GenerateSequence.from(0).to(numRecords))
+        .apply(ParDo.of(new GenerateMutations(options.getTable())))
+        .apply(
+            SpannerIO.write()
+                .withProjectId(project)
+                .withInstanceId(options.getInstanceId())
+                .withDatabaseId(databaseName));
+
+    p.apply("second step", GenerateSequence.from(numRecords).to(2 * 
numRecords))
+        .apply("Gen mutations", ParDo.of(new 
GenerateMutations(options.getTable())))
+        .apply(Wait.on(stepOne.getOutput()))
+        .apply("write to table2",
+            SpannerIO.write()
+                .withProjectId(project)
+                .withInstanceId(options.getInstanceId())
+                .withDatabaseId(databaseName));
+
+    PipelineResult result = p.run();
+    result.waitUntilFinish();
+    assertThat(result.getState(), is(PipelineResult.State.DONE));
+    assertThat(countNumberOfRecords(), equalTo(2L * numRecords));
+  }
+
+  @Test
+  public void testReportFailures() throws Exception {
+    int numRecords = 100;
+    p.apply(GenerateSequence.from(0).to(2 * numRecords))
+        .apply(ParDo.of(new GenerateMutations(options.getTable(), new 
DivBy2())))
+        .apply(
+            SpannerIO.write()
+                .withProjectId(project)
+                .withInstanceId(options.getInstanceId())
+                .withDatabaseId(databaseName)
+                .withFailureMode(SpannerIO.FailureMode.REPORT_FAILURES));
+
+    PipelineResult result = p.run();
+    result.waitUntilFinish();
+    assertThat(result.getState(), is(PipelineResult.State.DONE));
+    assertThat(countNumberOfRecords(), equalTo((long) numRecords));
+  }
+
+  @Test
+  public void testFailFast() throws Exception {
+    thrown.expect(Pipeline.PipelineExecutionException.class);
+    thrown.expectMessage(Matchers.containsString("Value must not be NULL in 
table users"));
+    int numRecords = 100;
+    p.apply(GenerateSequence.from(0).to(2 * numRecords))
+        .apply(ParDo.of(new GenerateMutations(options.getTable(), new 
DivBy2())))
+        .apply(
+            SpannerIO.write()
+                .withProjectId(project)
+                .withInstanceId(options.getInstanceId())
+                .withDatabaseId(databaseName));
+
+    PipelineResult result = p.run();
+    result.waitUntilFinish();
   }
 
   @After
@@ -152,9 +214,15 @@ public void tearDown() throws Exception {
   private static class GenerateMutations extends DoFn<Long, Mutation> {
     private final String table;
     private final int valueSize = 100;
+    private final Predicate<Long> injectError;
 
-    public GenerateMutations(String table) {
+    public GenerateMutations(String table, Predicate<Long> injectError) {
       this.table = table;
+      this.injectError = injectError;
+    }
+
+    public GenerateMutations(String table) {
+      this(table, Predicates.<Long>alwaysFalse());
     }
 
     @ProcessElement
@@ -162,9 +230,32 @@ public void processElement(ProcessContext c) {
       Mutation.WriteBuilder builder = Mutation.newInsertOrUpdateBuilder(table);
       Long key = c.element();
       builder.set("Key").to(key);
-      builder.set("Value").to(RandomUtils.randomAlphaNumeric(valueSize));
+      String value = injectError.apply(key) ? null : 
RandomUtils.randomAlphaNumeric(valueSize);
+      builder.set("Value").to(value);
       Mutation mutation = builder.build();
       c.output(mutation);
     }
   }
+
+  private long countNumberOfRecords() {
+    ResultSet resultSet =
+        spanner
+            .getDatabaseClient(DatabaseId.of(project, options.getInstanceId(), 
databaseName))
+            .singleUse()
+            .executeQuery(Statement.of("SELECT COUNT(*) FROM " + 
options.getTable()));
+    assertThat(resultSet.next(), is(true));
+    long result = resultSet.getLong(0);
+    assertThat(resultSet.next(), is(false));
+    return result;
+  }
+
+  private static class DivBy2 implements Predicate<Long>, Serializable {
+
+    @Override
+    public boolean apply(@Nullable Long input) {
+      return input % 2 == 0;
+    }
+  }
+
+
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Issue Time Tracking
-------------------

    Worklog Id:     (was: 99741)
    Time Spent: 3h 10m  (was: 3h)

> Chaining SpannerIO#write() transforms
> -------------------------------------
>
>                 Key: BEAM-4061
>                 URL: https://issues.apache.org/jira/browse/BEAM-4061
>             Project: Beam
>          Issue Type: Bug
>          Components: io-java-gcp
>            Reporter: Mairbek Khadikov
>            Assignee: Mairbek Khadikov
>            Priority: Major
>             Fix For: 2.5.0
>
>          Time Spent: 3h 10m
>  Remaining Estimate: 0h
>
> It should be possible to chain several Cloud Spanner writes. In practice, we
> can leverage the Wait.on transform by returning a result object from
> SpannerIO#write.
> One example where this feature is useful is a full database import, in which
> data in parent tables must be written before data in their interleaved tables.
> See more about table hierarchies in Spanner here:
> https://cloud.google.com/spanner/docs/schema-and-data-model#creating_a_hierarchy_of_interleaved_tables



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to