[ 
https://issues.apache.org/jira/browse/BEAM-3506?focusedWorklogId=109805&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-109805
 ]

ASF GitHub Bot logged work on BEAM-3506:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 07/Jun/18 17:30
            Start Date: 07/Jun/18 17:30
    Worklog Time Spent: 10m 
      Work Description: jkff closed pull request #4457: [BEAM-3506] - Add a 
feature in JdbcIO that allows writing PCollection<Iterable<T>>
URL: https://github.com/apache/beam/pull/4457
 
 
   

This pull request was merged from a forked repository.
Because GitHub hides the original diff once a pull request is merged, the diff
is reproduced below for the sake of provenance:

Because this is a foreign pull request (from a fork), GitHub would not
otherwise display its diff, so the diff is supplied below:

diff --git 
a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java 
b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
index b134ec02ddf..57301fecb9c 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
@@ -25,6 +25,7 @@
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.SQLException;
+
 import javax.annotation.Nullable;
 import javax.sql.DataSource;
 import org.apache.beam.sdk.annotations.Experimental;
@@ -154,13 +155,23 @@
     return new AutoValue_JdbcIO_ReadAll.Builder<ParameterT, OutputT>().build();
   }
 
+  private static final long DEFAULT_BATCH_SIZE = 1000L;
+
   /**
    * Write data to a JDBC datasource.
    *
    * @param <T> Type of the data to be written.
    */
   public static <T> Write<T> write() {
-    return new AutoValue_JdbcIO_Write.Builder<T>().build();
+    return new AutoValue_JdbcIO_Write.Builder<T>()
+            .setBatchSize(DEFAULT_BATCH_SIZE)
+            .build();
+  }
+
+  public static <T> WriteIterable<T> writeIterable() {
+    return new AutoValue_JdbcIO_WriteIterable.Builder<T>()
+            .setBatchSize(DEFAULT_BATCH_SIZE)
+            .build();
   }
 
   private JdbcIO() {}
@@ -506,57 +517,38 @@ public void teardown() throws Exception {
     void setParameters(T element, PreparedStatement preparedStatement) throws 
Exception;
   }
 
-  /** A {@link PTransform} to write to a JDBC datasource. */
-  @AutoValue
-  public abstract static class Write<T> extends PTransform<PCollection<T>, 
PDone> {
+  /** Abstract base class for JdbcIO write operations. */
+  abstract static class AbstractWrite<RowT, InputT> extends 
PTransform<PCollection<InputT>, PDone> {
     @Nullable abstract DataSourceConfiguration getDataSourceConfiguration();
     @Nullable abstract String getStatement();
-    @Nullable abstract PreparedStatementSetter<T> getPreparedStatementSetter();
-
-    abstract Builder<T> toBuilder();
-
-    @AutoValue.Builder
-    abstract static class Builder<T> {
-      abstract Builder<T> setDataSourceConfiguration(DataSourceConfiguration 
config);
-      abstract Builder<T> setStatement(String statement);
-      abstract Builder<T> 
setPreparedStatementSetter(PreparedStatementSetter<T> setter);
-
-      abstract Write<T> build();
-    }
-
-    public Write<T> withDataSourceConfiguration(DataSourceConfiguration 
config) {
-      return toBuilder().setDataSourceConfiguration(config).build();
-    }
-    public Write<T> withStatement(String statement) {
-      return toBuilder().setStatement(statement).build();
-    }
-    public Write<T> withPreparedStatementSetter(PreparedStatementSetter<T> 
setter) {
-      return toBuilder().setPreparedStatementSetter(setter).build();
-    }
+    abstract long getBatchSize();
+    @Nullable abstract PreparedStatementSetter<RowT> 
getPreparedStatementSetter();
 
     @Override
-    public PDone expand(PCollection<T> input) {
+    public PDone expand(PCollection<InputT> input) {
       checkArgument(
           getDataSourceConfiguration() != null, "withDataSourceConfiguration() 
is required");
       checkArgument(getStatement() != null, "withStatement() is required");
       checkArgument(
           getPreparedStatementSetter() != null, "withPreparedStatementSetter() 
is required");
 
-      input.apply(ParDo.of(new WriteFn<T>(this)));
+      input.apply(ParDo.of(createWriteFn(this)));
       return PDone.in(input.getPipeline());
     }
 
-    private static class WriteFn<T> extends DoFn<T, Void> {
-      private static final int DEFAULT_BATCH_SIZE = 1000;
+    protected abstract AbstractWriteFn<RowT, InputT>
+            createWriteFn(AbstractWrite<RowT, InputT> write);
+
+    abstract static class AbstractWriteFn<RowT, InputT> extends DoFn<InputT, 
Void> {
+      protected final AbstractWrite<RowT, InputT> spec;
 
-      private final Write<T> spec;
+      protected transient DataSource dataSource;
+      protected transient Connection connection;
+      protected transient PreparedStatement preparedStatement;
 
-      private DataSource dataSource;
-      private Connection connection;
-      private PreparedStatement preparedStatement;
-      private int batchCount;
+      protected transient int batchCount;
 
-      public WriteFn(Write<T> spec) {
+      public AbstractWriteFn(AbstractWrite<RowT, InputT> spec) {
         this.spec = spec;
       }
 
@@ -574,26 +566,14 @@ public void startBundle() {
       }
 
       @ProcessElement
-      public void processElement(ProcessContext context) throws Exception {
-        T record = context.element();
-
-        preparedStatement.clearParameters();
-        spec.getPreparedStatementSetter().setParameters(record, 
preparedStatement);
-        preparedStatement.addBatch();
-
-        batchCount++;
-
-        if (batchCount >= DEFAULT_BATCH_SIZE) {
-          executeBatch();
-        }
-      }
+      public abstract void processElement(ProcessContext context) throws 
Exception;
 
       @FinishBundle
       public void finishBundle() throws Exception {
         executeBatch();
       }
 
-      private void executeBatch() throws SQLException {
+      protected void executeBatch() throws SQLException {
         if (batchCount > 0) {
           preparedStatement.executeBatch();
           connection.commit();
@@ -603,6 +583,7 @@ private void executeBatch() throws SQLException {
 
       @Teardown
       public void teardown() throws Exception {
+        executeBatch();
         try {
           if (preparedStatement != null) {
             preparedStatement.close();
@@ -619,6 +600,150 @@ public void teardown() throws Exception {
     }
   }
 
+  /** A {@link PTransform} to write to a JDBC datasource. */
+  @AutoValue
+  public abstract static class Write<T> extends AbstractWrite<T, T> {
+    // Override is needed, otherwise the code generator complains about the 
type parameter
+    @Nullable abstract PreparedStatementSetter<T> getPreparedStatementSetter();
+
+    abstract Builder<T> toBuilder();
+
+    @AutoValue.Builder
+    abstract static class Builder<T> {
+      abstract Builder<T> setDataSourceConfiguration(DataSourceConfiguration 
config);
+      abstract Builder<T> setStatement(String statement);
+      abstract Builder<T> setBatchSize(long batchSize);
+      abstract Builder<T> 
setPreparedStatementSetter(PreparedStatementSetter<T> setter);
+
+      abstract Write<T> build();
+    }
+
+    public Write<T> withDataSourceConfiguration(DataSourceConfiguration 
config) {
+      return toBuilder().setDataSourceConfiguration(config).build();
+    }
+    public Write<T> withStatement(String statement) {
+      return toBuilder().setStatement(statement).build();
+    }
+    public Write<T> withPreparedStatementSetter(PreparedStatementSetter<T> 
setter) {
+      return toBuilder().setPreparedStatementSetter(setter).build();
+    }
+
+    /**
+     * Provide a maximum size in number of SQL statement for the batch. 
Default is 1000.
+     *
+     * @param batchSize maximum batch size in number of statements
+     * @return the {@link Write} with connection batch size set
+     */
+    public Write<T> withBatchSize(long batchSize) {
+      checkArgument(batchSize > 0, "batchSize must be > 0, but was %d", 
batchSize);
+      return toBuilder().setBatchSize(batchSize).build();
+    }
+
+    @Override
+    protected WriteFn<T> createWriteFn(AbstractWrite<T, T> write) {
+      return new WriteFn<T>(write);
+    }
+
+    private static class WriteFn<T> extends AbstractWriteFn<T, T> {
+
+      public WriteFn(AbstractWrite<T, T> spec) {
+        super(spec);
+      }
+
+      @ProcessElement
+      @Override
+      public void processElement(ProcessContext context) throws Exception {
+        T record = context.element();
+
+        preparedStatement.clearParameters();
+        spec.getPreparedStatementSetter().setParameters(record, 
preparedStatement);
+        preparedStatement.addBatch();
+
+        batchCount++;
+
+        if (batchCount >= spec.getBatchSize()) {
+          executeBatch();
+        }
+      }
+    }
+  }
+
+  /**
+   * A {@link PTransform} to write to a JDBC datasource where the input is an
+   * Iterable of the row type.
+   */
+  @AutoValue
+  public abstract static class WriteIterable<T> extends AbstractWrite<T, 
Iterable<T>> {
+    // Override is needed, otherwise the code generator complains about the 
type parameter
+    @Nullable abstract PreparedStatementSetter<T> getPreparedStatementSetter();
+
+    abstract Builder<T> toBuilder();
+
+    @AutoValue.Builder
+    abstract static class Builder<T> {
+      abstract Builder<T> setDataSourceConfiguration(DataSourceConfiguration 
config);
+      abstract Builder<T> setStatement(String statement);
+      abstract Builder<T> setBatchSize(long batchSize);
+      abstract Builder<T> 
setPreparedStatementSetter(PreparedStatementSetter<T> setter);
+
+      abstract WriteIterable<T> build();
+    }
+
+    public WriteIterable<T> 
withDataSourceConfiguration(DataSourceConfiguration config) {
+      return toBuilder().setDataSourceConfiguration(config).build();
+    }
+    public WriteIterable<T> withStatement(String statement) {
+      return toBuilder().setStatement(statement).build();
+    }
+    public WriteIterable<T> 
withPreparedStatementSetter(PreparedStatementSetter<T> setter) {
+      return toBuilder().setPreparedStatementSetter(setter).build();
+    }
+
+    /**
+     * Provide a maximum size in number of statements for the batch. If the 
number of
+     * records in an Iterable exceeds this number, a commit() will be executed 
for every
+     * batchSize records.
+     * Default is 1000.
+     *
+     * @param batchSize maximum batch size in number of statements
+     * @return the {@link Write} with connection batch size set
+     */
+    public WriteIterable<T> withBatchSize(long batchSize) {
+      checkArgument(batchSize > 0, "batchSize must be > 0, but was %d", 
batchSize);
+      return toBuilder().setBatchSize(batchSize).build();
+    }
+
+    @Override
+    protected WriteIterableFn<T> createWriteFn(AbstractWrite<T, Iterable<T>> 
write) {
+      return new WriteIterableFn<T>(write);
+    }
+
+    private static class WriteIterableFn<T> extends AbstractWriteFn<T, 
Iterable<T>> {
+
+      public WriteIterableFn(AbstractWrite<T, Iterable<T>> spec) {
+        super(spec);
+      }
+
+      @ProcessElement
+      @Override
+      public void processElement(ProcessContext context) throws Exception {
+        Iterable<T> records = context.element();
+
+        for (T record : records) {
+          preparedStatement.clearParameters();
+          spec.getPreparedStatementSetter().setParameters(record, 
preparedStatement);
+          preparedStatement.addBatch();
+
+          batchCount++;
+
+          if (batchCount >= spec.getBatchSize()) {
+            executeBatch();
+          }
+        }
+      }
+    }
+  }
+
   private static class Reparallelize<T> extends PTransform<PCollection<T>, 
PCollection<T>> {
     @Override
     public PCollection<T> expand(PCollection<T> input) {
diff --git 
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java 
b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
index ed169c72254..270bed66289 100644
--- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
+++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.io.jdbc;
 
+import java.io.Serializable;
 import java.sql.SQLException;
 import java.util.List;
 import org.apache.beam.sdk.coders.SerializableCoder;
@@ -30,8 +31,11 @@
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Top;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
@@ -65,7 +69,8 @@
 
   private static int numberOfRows;
   private static PGSimpleDataSource dataSource;
-  private static String tableName;
+  private static String writeTableName;
+  private static String writeIterableTableName;
 
   @Rule
   public TestPipeline pipelineWrite = TestPipeline.create();
@@ -80,13 +85,16 @@ public static void setup() throws SQLException {
 
     numberOfRows = options.getNumberOfRecords();
     dataSource = DatabaseTestHelper.getPostgresDataSource(options);
-    tableName = DatabaseTestHelper.getTestTableName("IT");
-    DatabaseTestHelper.createTable(dataSource, tableName);
+    writeTableName = DatabaseTestHelper.getTestTableName("IT");
+    writeIterableTableName = 
DatabaseTestHelper.getTestTableName("IT_ITERABLE");
+    DatabaseTestHelper.createTable(dataSource, writeTableName);
+    DatabaseTestHelper.createTable(dataSource, writeIterableTableName);
   }
 
   @AfterClass
   public static void tearDown() throws SQLException {
-    DatabaseTestHelper.deleteTable(dataSource, tableName);
+    DatabaseTestHelper.deleteTable(dataSource, writeTableName);
+    DatabaseTestHelper.deleteTable(dataSource, writeIterableTableName);
   }
 
   /**
@@ -95,7 +103,16 @@ public static void tearDown() throws SQLException {
   @Test
   public void testWriteThenRead() {
     runWrite();
-    runRead();
+    runRead(writeTableName);
+  }
+
+  /**
+   * Tests writing iterables then reading data for a postgres database.
+   */
+  @Test
+  public void testWriteIterableThenRead() {
+    runWriteIterable();
+    runRead(writeIterableTableName);
   }
 
   /**
@@ -111,7 +128,37 @@ private void runWrite() {
         .apply(ParDo.of(new TestRow.DeterministicallyConstructTestRowFn()))
         .apply(JdbcIO.<TestRow>write()
             
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
-            .withStatement(String.format("insert into %s values(?, ?)", 
tableName))
+            .withStatement(String.format("insert into %s values(?, ?)", 
writeTableName))
+            .withPreparedStatementSetter(new 
JdbcTestHelper.PrepareStatementFromTestRow()));
+
+    pipelineWrite.run().waitUntilFinish();
+  }
+
+  static class ModuloTen
+        extends DoFn<TestRow, KV<Integer, TestRow>> {
+    @ProcessElement
+    public void processElement(ProcessContext context) throws Exception {
+      context.output(KV.of(context.element().id() % 10, context.element()));
+    }
+  }
+  private static class GetValues
+        extends DoFn<KV<Integer, Iterable<TestRow>>, Iterable<TestRow>>
+        implements Serializable {
+    @ProcessElement
+    public void processElement(ProcessContext context) throws Exception {
+      context.output(context.element().getValue());
+    }
+  }
+
+  private void runWriteIterable() {
+    pipelineWrite.apply(GenerateSequence.from(0).to(numberOfRows))
+        .apply(ParDo.of(new TestRow.DeterministicallyConstructTestRowFn()))
+        .apply(ParDo.of(new ModuloTen()))
+        .apply(GroupByKey.<Integer, TestRow> create())
+        .apply(ParDo.of(new GetValues()))
+        .apply(JdbcIO.<TestRow>writeIterable()
+            
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
+            .withStatement(String.format("insert into %s values(?, ?)", 
writeIterableTableName))
             .withPreparedStatementSetter(new 
JdbcTestHelper.PrepareStatementFromTestRow()));
 
     pipelineWrite.run().waitUntilFinish();
@@ -134,8 +181,10 @@ private void runWrite() {
    * 2. Use containsInAnyOrder to verify that their values are correct.
    * Where first/last 500 rows is determined by the fact that we know all rows 
have a unique id - we
    * can use the natural ordering of that key.
+   *
+   * @param tableName The table to read and verify
    */
-  private void runRead() {
+  private void runRead(String tableName) {
     PCollection<TestRow> namesAndIds =
         pipelineRead.apply(JdbcIO.<TestRow>read()
         
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
diff --git 
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java 
b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
index 4871f20a58e..d2369e9fe06 100644
--- 
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
+++ 
b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
@@ -26,11 +26,10 @@
 import java.net.ServerSocket;
 import java.sql.Connection;
 import java.sql.PreparedStatement;
-import java.sql.ResultSet;
 import java.sql.SQLException;
-import java.sql.Statement;
 import java.util.ArrayList;
 import java.util.Collections;
+
 import javax.sql.DataSource;
 
 import org.apache.beam.sdk.coders.KvCoder;
@@ -43,12 +42,14 @@
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.derby.drda.NetworkServerControl;
 import org.apache.derby.jdbc.ClientDataSource;
 import org.junit.AfterClass;
-import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
@@ -67,6 +68,8 @@
 
   private static int port;
   private static String readTableName;
+  private static String writeTableName;
+  private static String writeIterableTableName;
 
   @Rule
   public final transient TestPipeline pipeline = TestPipeline.create();
@@ -113,8 +116,12 @@ public static void startDatabase() throws Exception {
     dataSource.setPortNumber(port);
 
     readTableName = DatabaseTestHelper.getTestTableName("UT_READ");
+    writeTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
+    writeIterableTableName = 
DatabaseTestHelper.getTestTableName("UT_WRITE_ITERABLE");
 
     DatabaseTestHelper.createTable(dataSource, readTableName);
+    DatabaseTestHelper.createTable(dataSource, writeTableName);
+    DatabaseTestHelper.createTable(dataSource, writeIterableTableName);
     addInitialData(dataSource, readTableName);
   }
 
@@ -122,6 +129,8 @@ public static void startDatabase() throws Exception {
   public static void shutDownDatabase() throws Exception {
     try {
       DatabaseTestHelper.deleteTable(dataSource, readTableName);
+      DatabaseTestHelper.deleteTable(dataSource, writeTableName);
+      DatabaseTestHelper.deleteTable(dataSource, writeIterableTableName);
     } finally {
       if (derbyServer != null) {
         derbyServer.shutdown();
@@ -252,47 +261,31 @@ public void setParameters(PreparedStatement 
preparedStatement)
 
   @Test
   public void testWrite() throws Exception {
-    final long rowsToAdd = 1000L;
+    final int rowsToAdd = EXPECTED_ROW_COUNT;
+    //String sql = "merge into %s as dest \n"
+    //        + "using %s as source \n"
+    //        + "on (dest.id=?) \n"
+    //        + "when not matched then insert values (?, ?)";
+    // This insert statement should be a merge statement, but that generates a
+    // NullPointerException in the Derby server
+    String sql = "insert into %s values(?, ?)";
+
+    pipeline.apply(Create.of(createTestDataForWrite(rowsToAdd)))
+        .apply(JdbcIO.<KV<Integer, String>>write()
+            
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
+            .withStatement(String.format(sql, writeTableName))
+            .withBatchSize(1000L)
+            .withPreparedStatementSetter(
+                new JdbcIO.PreparedStatementSetter<KV<Integer, String>>() {
+              public void setParameters(
+                  KV<Integer, String> element, PreparedStatement statement) 
throws Exception {
+                statement.setInt(1, element.getKey());
+                statement.setString(2, element.getValue());
+              }
+            }));
 
-    String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
-    DatabaseTestHelper.createTable(dataSource, tableName);
-    try {
-      ArrayList<KV<Integer, String>> data = new ArrayList<>();
-      for (int i = 0; i < rowsToAdd; i++) {
-        KV<Integer, String> kv = KV.of(i, "Test");
-        data.add(kv);
-      }
-      pipeline.apply(Create.of(data))
-          .apply(JdbcIO.<KV<Integer, String>>write()
-              
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(
-                  "org.apache.derby.jdbc.ClientDriver",
-                  "jdbc:derby://localhost:" + port + "/target/beam"))
-              .withStatement(String.format("insert into %s values(?, ?)", 
tableName))
-              .withPreparedStatementSetter(
-                  new JdbcIO.PreparedStatementSetter<KV<Integer, String>>() {
-                public void setParameters(
-                    KV<Integer, String> element, PreparedStatement statement) 
throws Exception {
-                  statement.setInt(1, element.getKey());
-                  statement.setString(2, element.getValue());
-                }
-              }));
-
-      pipeline.run();
-
-      try (Connection connection = dataSource.getConnection()) {
-        try (Statement statement = connection.createStatement()) {
-          try (ResultSet resultSet = statement.executeQuery("select count(*) 
from "
-                + tableName)) {
-            resultSet.next();
-            int count = resultSet.getInt(1);
-
-            Assert.assertEquals(EXPECTED_ROW_COUNT, count);
-          }
-        }
-      }
-    } finally {
-      DatabaseTestHelper.deleteTable(dataSource, tableName);
-    }
+    pipeline.run();
+    verifyWriteResults(writeTableName);
   }
 
   @Test
@@ -300,9 +293,7 @@ public void testWriteWithEmptyPCollection() throws 
Exception {
     pipeline
         .apply(Create.empty(KvCoder.of(VarIntCoder.of(), 
StringUtf8Coder.of())))
         .apply(JdbcIO.<KV<Integer, String>>write()
-            .withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(
-                "org.apache.derby.jdbc.ClientDriver",
-                "jdbc:derby://localhost:" + port + "/target/beam"))
+            
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
             .withStatement("insert into BEAM values(?, ?)")
             .withPreparedStatementSetter(new 
JdbcIO.PreparedStatementSetter<KV<Integer, String>>() {
               public void setParameters(KV<Integer, String> element, 
PreparedStatement statement)
@@ -314,4 +305,77 @@ public void setParameters(KV<Integer, String> element, 
PreparedStatement stateme
 
     pipeline.run();
   }
+
+  private static class ModuloTen
+        extends DoFn<KV<Integer, String>, KV<Integer, KV<Integer, String>>> {
+    @ProcessElement
+    public void processElement(ProcessContext context) throws Exception {
+      context.output(KV.of(context.element().getKey() % 10, 
context.element()));
+    }
+  }
+
+  private static class GetValues
+        extends DoFn<KV<Integer, Iterable<KV<Integer, String>>>, 
Iterable<KV<Integer, String>>>
+        implements Serializable {
+    @ProcessElement
+    public void processElement(ProcessContext context) throws Exception {
+      context.output(context.element().getValue());
+    }
+  }
+
+  @Test
+  public void testWriteIterable() throws Exception {
+    final int rowsToAdd = EXPECTED_ROW_COUNT;
+    // This insert statement should be a merge statement, but that generates 
an error
+    // in the Derby driver
+    String sql = "insert into %s values(?, ?)";
+
+    pipeline.apply(Create.of(createTestDataForWrite(rowsToAdd)))
+        .apply(ParDo.of(new ModuloTen()))
+        .apply(GroupByKey.<Integer, KV<Integer, String>> create())
+        .apply(ParDo.of(new GetValues()))
+        .apply(JdbcIO.<KV<Integer, String>>writeIterable()
+            
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
+            .withStatement(String.format(sql, writeIterableTableName))
+            .withPreparedStatementSetter(
+                new JdbcIO.PreparedStatementSetter<KV<Integer, String>>() {
+              public void setParameters(
+                  KV<Integer, String> element, PreparedStatement statement) 
throws Exception {
+                statement.setInt(1, element.getKey());
+                statement.setString(2, element.getValue());
+              }
+            }));
+
+    pipeline.run();
+    verifyWriteResults(writeIterableTableName);
+  }
+
+  private ArrayList<KV<Integer, String>> createTestDataForWrite(int rowsToAdd) 
{
+    ArrayList<KV<Integer, String>> data = new ArrayList<>();
+    for (int i = 0; i < rowsToAdd; i++) {
+      KV<Integer, String> kv = KV.of(i, "Testval" + i);
+      data.add(kv);
+    }
+    return data;
+  }
+
+  private void verifyWriteResults(String tableName) {
+    PCollection<TestRow> rows = pipeline.apply(
+          JdbcIO.<TestRow>read()
+              
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
+              // Distinct as merge statements do not work on the integrated 
Derby db
+              // and it is therefore possible that rows have been inserted 
multiple times
+              .withQuery("select distinct name,id from " + tableName)
+              .withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId())
+              .withCoder(SerializableCoder.of(TestRow.class)));
+
+    PAssert.thatSingleton(
+          rows.apply("Count All", Count.<TestRow>globally()))
+          .isEqualTo((long) EXPECTED_ROW_COUNT);
+
+    Iterable<TestRow> expectedValues = TestRow.getExpectedValues(0, 
EXPECTED_ROW_COUNT);
+    PAssert.that(rows).containsInAnyOrder(expectedValues);
+
+    pipeline.run();
+  }
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Issue Time Tracking
-------------------

            Worklog Id:     (was: 109805)
            Time Spent: 0.5h  (was: 20m)
    Remaining Estimate: 3.5h  (was: 3h 40m)

> JdbcIO: Support writing iterables (i.e. collections) of rows instead of only 
> single rows
> ----------------------------------------------------------------------------------------
>
>                 Key: BEAM-3506
>                 URL: https://issues.apache.org/jira/browse/BEAM-3506
>             Project: Beam
>          Issue Type: Improvement
>          Components: z-do-not-use-sdk-java-extensions
>    Affects Versions: 2.3.0
>            Reporter: Knut Olav Loite
>            Assignee: Jean-Baptiste Onofré
>            Priority: Minor
>              Labels: JdbcIO, jdbc
>             Fix For: Not applicable
>
>   Original Estimate: 4h
>          Time Spent: 0.5h
>  Remaining Estimate: 3.5h
>
> The current JdbcIO write interface expects a PCollection<T>, where T is the
> row to be written. Each instance of T is then added to a batch and written to
> the database. The user has little control over how many rows are added to
> one batch. If JdbcIO also supported writing a PCollection<Iterable<T>>,
> the user would have more control over the number of rows in one batch.
> Batching multiple rows together is especially important for performance when
> writing to cloud databases such as Google Cloud Spanner.
> I already have a solution locally and I will submit a pull request.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to