This is an automated email from the ASF dual-hosted git repository.

pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e18e2b9  Merge pull request #16579 from Revert "Revert "Merge pull request #15863 from [BEAM-13184] Autoshard…
e18e2b9 is described below

commit e18e2b9ba26bee41e93c3109318672eb7b1b26bc
Author: Pablo <[email protected]>
AuthorDate: Thu Jan 27 19:41:53 2022 -0800

    Merge pull request #16579 from Revert "Revert "Merge pull request #15863 from [BEAM-13184] Autoshard…
    
    * Revert "Revert "Merge pull request #15863 from [BEAM-13184] Autosharding for JdbcIO.write* transforms""
    
    This reverts commit 421bc8068fc561a358cfbf6c9842408672872120.
    
    * Using batchSize to define element batch size
    
    * Handle corner case for null list
---
 .../java/org/apache/beam/sdk/io/jdbc/JdbcIO.java   | 112 +++++++++++++++++----
 .../java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java |  39 +++++++
 .../org/apache/beam/sdk/io/jdbc/JdbcIOTest.java    |  30 ++++++
 3 files changed, 164 insertions(+), 17 deletions(-)

diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
index 0f6b9c3..14f5e69 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
@@ -66,15 +66,19 @@ import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Filter;
 import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.GroupIntoBatches;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Reshuffle;
 import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.SerializableFunctions;
+import org.apache.beam.sdk.transforms.Values;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.Wait;
+import org.apache.beam.sdk.transforms.WithKeys;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.display.HasDisplayData;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.util.BackOff;
 import org.apache.beam.sdk.util.BackOffUtils;
 import org.apache.beam.sdk.util.FluentBackoff;
@@ -82,6 +86,7 @@ import org.apache.beam.sdk.util.Sleeper;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.sdk.values.Row;
@@ -96,6 +101,7 @@ import org.apache.commons.pool2.impl.GenericObjectPool;
 import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
 import org.checkerframework.checker.nullness.qual.Nullable;
 import org.joda.time.Duration;
+import org.joda.time.Instant;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -1318,6 +1324,11 @@ public class JdbcIO {
       this.inner = inner;
     }
 
+    /**
+     * Returns a new {@link Write} that wraps a copy of the inner {@link WriteVoid} with
+     * auto-sharding enabled. See {@link WriteVoid#withAutoSharding()}.
+     */
+    public Write<T> withAutoSharding() {
+      return new Write<>(inner.withAutoSharding());
+    }
+
     /** See {@link 
WriteVoid#withDataSourceConfiguration(DataSourceConfiguration)}. */
     public Write<T> withDataSourceConfiguration(DataSourceConfiguration 
config) {
       return new Write<>(inner.withDataSourceConfiguration(config));
@@ -1393,6 +1404,7 @@ public class JdbcIO {
           .setPreparedStatementSetter(inner.getPreparedStatementSetter())
           .setStatement(inner.getStatement())
           .setTable(inner.getTable())
+          .setAutoSharding(inner.getAutoSharding())
           .build();
     }
 
@@ -1408,6 +1420,51 @@ public class JdbcIO {
     }
   }
 
+  /**
+   * Groups {@code input} into batches of elements for {@link WriteFn}.
+   *
+   * <p>When {@code input} is unbounded and {@code withAutoSharding} is {@code true}, elements are
+   * batched by {@link GroupIntoBatches} on a sharded key, and an incomplete batch is buffered for
+   * at most 200ms of processing time. Otherwise elements are buffered within each bundle: a batch
+   * is emitted as soon as it holds {@code batchSize} elements, and any remainder is emitted when
+   * the bundle finishes.
+   *
+   * @param input the elements to batch
+   * @param withAutoSharding whether to auto-shard; {@code null} is treated as {@code false}
+   * @param batchSize the maximum number of elements in a single batch
+   * @return the batches; none holds more than {@code batchSize} elements
+   */
+  static <T> PCollection<Iterable<T>> batchElements(
+      PCollection<T> input, @Nullable Boolean withAutoSharding, long batchSize) {
+    PCollection<Iterable<T>> iterables;
+    if (input.isBounded() == IsBounded.UNBOUNDED && Boolean.TRUE.equals(withAutoSharding)) {
+      iterables =
+          input
+              .apply(WithKeys.<String, T>of(""))
+              .apply(
+                  GroupIntoBatches.<String, T>ofSize(batchSize)
+                      .withMaxBufferingDuration(Duration.millis(200))
+                      .withShardedKey())
+              .apply(Values.create());
+    } else {
+      iterables =
+          input.apply(
+              ParDo.of(
+                  new DoFn<T, Iterable<T>>() {
+                    // Elements buffered in the current bundle; null when nothing is pending.
+                    List<T> outputList;
+
+                    @ProcessElement
+                    public void process(ProcessContext c) {
+                      if (outputList == null) {
+                        outputList = new ArrayList<>();
+                      }
+                      outputList.add(c.element());
+                      // Flush as soon as the batch is full so that no batch exceeds batchSize.
+                      if (outputList.size() >= batchSize) {
+                        c.output(outputList);
+                        outputList = null;
+                      }
+                    }
+
+                    @FinishBundle
+                    public void finish(FinishBundleContext c) {
+                      if (outputList != null && !outputList.isEmpty()) {
+                        // NOTE(review): the remainder is emitted into the global window at the
+                        // current processing time, and process() emits a batch in the window of
+                        // its last element. Elements from other windows can therefore end up in
+                        // the wrong window; windowed input would need one buffer per window.
+                        c.output(outputList, Instant.now(), GlobalWindow.INSTANCE);
+                      }
+                      outputList = null;
+                    }
+                  }));
+    }
+    return iterables;
+  }
+
   /** Interface implemented by functions that sets prepared statement data. */
   @FunctionalInterface
   interface PreparedStatementSetCaller extends Serializable {
@@ -1430,6 +1487,8 @@ public class JdbcIO {
   @AutoValue
   public abstract static class WriteWithResults<T, V extends JdbcWriteResult>
       extends PTransform<PCollection<T>, PCollection<V>> {
+    abstract @Nullable Boolean getAutoSharding();
+
     abstract @Nullable SerializableFunction<Void, DataSource> 
getDataSourceProviderFn();
 
     abstract @Nullable ValueProvider<String> getStatement();
@@ -1451,6 +1510,8 @@ public class JdbcIO {
       abstract Builder<T, V> setDataSourceProviderFn(
           SerializableFunction<Void, DataSource> dataSourceProviderFn);
 
+      abstract Builder<T, V> setAutoSharding(Boolean autoSharding);
+
       abstract Builder<T, V> setStatement(ValueProvider<String> statement);
 
       abstract Builder<T, V> 
setPreparedStatementSetter(PreparedStatementSetter<T> setter);
@@ -1487,6 +1548,11 @@ public class JdbcIO {
       return toBuilder().setPreparedStatementSetter(setter).build();
     }
 
+    /**
+     * Enables writing with a dynamically determined number of shards.
+     *
+     * <p>Intended for unbounded (streaming) input, where elements are batched by {@code
+     * GroupIntoBatches} on a sharded key.
+     */
+    public WriteWithResults<T, V> withAutoSharding() {
+      return toBuilder().setAutoSharding(true).build();
+    }
+
     /**
      * When a SQL exception occurs, {@link Write} uses this {@link 
RetryStrategy} to determine if it
      * will retry the statements. If {@link RetryStrategy#apply(SQLException)} 
returns {@code true},
@@ -1549,8 +1615,15 @@ public class JdbcIO {
       checkArgument(
           (getDataSourceProviderFn() != null),
           "withDataSourceConfiguration() or withDataSourceProviderFn() is required");
-
-      return input.apply(
+      // Auto-sharding relies on GroupIntoBatches with a sharded key, which batchElements only
+      // applies to unbounded input, so reject it for bounded input.
+      checkArgument(
+          !Boolean.TRUE.equals(getAutoSharding()) || input.isBounded() == IsBounded.UNBOUNDED,
+          "Autosharding is only supported for streaming pipelines.");
+
+      PCollection<Iterable<T>> iterables =
+          JdbcIO.<T>batchElements(input, getAutoSharding(), DEFAULT_BATCH_SIZE);
+      return iterables.apply(
           ParDo.of(
               new WriteFn<T, V>(
                   WriteFnSpec.builder()
@@ -1573,6 +1646,8 @@ public class JdbcIO {
   @AutoValue
   public abstract static class WriteVoid<T> extends PTransform<PCollection<T>, 
PCollection<Void>> {
 
+    abstract @Nullable Boolean getAutoSharding();
+
     abstract @Nullable SerializableFunction<Void, DataSource> 
getDataSourceProviderFn();
 
     abstract @Nullable ValueProvider<String> getStatement();
@@ -1591,6 +1666,8 @@ public class JdbcIO {
 
     @AutoValue.Builder
     abstract static class Builder<T> {
+      abstract Builder<T> setAutoSharding(Boolean autoSharding);
+
       abstract Builder<T> setDataSourceProviderFn(
           SerializableFunction<Void, DataSource> dataSourceProviderFn);
 
@@ -1609,6 +1686,11 @@ public class JdbcIO {
       abstract WriteVoid<T> build();
     }
 
+    /**
+     * Enables writing with a dynamically determined number of shards.
+     *
+     * <p>Only affects unbounded (streaming) input, whose elements are then batched by {@code
+     * GroupIntoBatches} on a sharded key; bounded input ignores this setting and is batched within
+     * each bundle.
+     */
+    public WriteVoid<T> withAutoSharding() {
+      return toBuilder().setAutoSharding(true).build();
+    }
+
     public WriteVoid<T> withDataSourceConfiguration(DataSourceConfiguration 
config) {
       return withDataSourceProviderFn(new 
DataSourceProviderFromDataSourceConfiguration(config));
     }
@@ -1708,7 +1790,11 @@ public class JdbcIO {
         checkArgument(
             spec.getPreparedStatementSetter() != null, 
"withPreparedStatementSetter() is required");
       }
-      return input
+
+      PCollection<Iterable<T>> iterables =
+          JdbcIO.<T>batchElements(input, getAutoSharding(), getBatchSize());
+
+      return iterables
           .apply(
               ParDo.of(
                   new WriteFn<T, Void>(
@@ -1955,7 +2041,7 @@ public class JdbcIO {
    * @param <T>
    * @param <V>
    */
-  static class WriteFn<T, V> extends DoFn<T, V> {
+  static class WriteFn<T, V> extends DoFn<Iterable<T>, V> {
 
     @AutoValue
     abstract static class WriteFnSpec<T, V> implements Serializable, 
HasDisplayData {
@@ -2045,7 +2131,6 @@ public class JdbcIO {
     private Connection connection;
     private PreparedStatement preparedStatement;
     private static FluentBackoff retryBackOff;
-    private final List<T> records = new ArrayList<>();
 
     public WriteFn(WriteFnSpec<T, V> spec) {
       this.spec = spec;
@@ -2085,17 +2170,12 @@ public class JdbcIO {
 
     @ProcessElement
     public void processElement(ProcessContext context) throws Exception {
-      T record = context.element();
-      records.add(record);
-      if (records.size() >= spec.getBatchSize()) {
-        executeBatch(context);
-      }
+      // Each input element is a complete batch built upstream by batchElements().
+      executeBatch(context, context.element());
     }
 
     @FinishBundle
     public void finishBundle() throws Exception {
-      // We pass a null context because we only execute a final batch for WriteVoid cases.
-      executeBatch(null);
       cleanUpStatementAndConnection();
     }
 
@@ -2124,11 +2204,8 @@ public class JdbcIO {
       }
     }
 
-    private void executeBatch(ProcessContext context)
+    private void executeBatch(ProcessContext context, Iterable<T> records)
         throws SQLException, IOException, InterruptedException {
-      if (records.isEmpty()) {
-        return;
-      }
       Long startTimeNs = System.nanoTime();
       Sleeper sleeper = Sleeper.DEFAULT;
       BackOff backoff = retryBackOff.backoff();
@@ -2137,8 +2214,10 @@ public class JdbcIO {
             getConnection().prepareStatement(spec.getStatement().get())) {
           try {
             // add each record in the statement batch
+            int recordsInBatch = 0;
             for (T record : records) {
               processRecord(record, preparedStatement, context);
+              recordsInBatch += 1;
             }
             if (!spec.getReturnResults()) {
               // execute the batch
@@ -2146,7 +2225,7 @@ public class JdbcIO {
               // commit the changes
               getConnection().commit();
             }
-            RECORDS_PER_BATCH.update(records.size());
+            RECORDS_PER_BATCH.update(recordsInBatch);
             
MS_PER_BATCH.update(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - 
startTimeNs));
             break;
           } catch (SQLException exception) {
@@ -2164,7 +2243,6 @@ public class JdbcIO {
           }
         }
       }
-      records.clear();
     }
 
     private void processRecord(T record, PreparedStatement preparedStatement, 
ProcessContext c) {
diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
index 8cebbbd..59bc764 100644
--- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
+++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
@@ -32,6 +32,9 @@ import java.util.Set;
 import java.util.UUID;
 import java.util.function.Function;
 import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.io.GenerateSequence;
 import org.apache.beam.sdk.io.common.DatabaseTestHelper;
 import org.apache.beam.sdk.io.common.HashingFn;
@@ -39,6 +42,7 @@ import 
org.apache.beam.sdk.io.common.PostgresIOTestPipelineOptions;
 import org.apache.beam.sdk.io.common.TestRow;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestStream;
 import org.apache.beam.sdk.testutils.NamedTestResult;
 import org.apache.beam.sdk.testutils.metrics.IOITMetrics;
 import org.apache.beam.sdk.testutils.metrics.MetricsReader;
@@ -51,6 +55,7 @@ import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Top;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.joda.time.Instant;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Rule;
@@ -255,6 +260,40 @@ public class JdbcIOIT {
   }
 
   @Test
+  public void testWriteWithAutosharding() throws Exception {
+    String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
+    DatabaseTestHelper.createTable(dataSource, firstTableName);
+    try {
+      List<KV<Integer, String>> data = getTestDataToWrite(EXPECTED_ROW_COUNT);
+      TestStream.Builder<KV<Integer, String>> ts =
+          TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()))
+              .advanceWatermarkTo(Instant.now());
+      for (KV<Integer, String> elm : data) {
+        // TestStream.Builder is immutable: keep the builder returned by addElements.
+        ts = ts.addElements(elm);
+      }
+
+      PCollection<KV<Integer, String>> dataCollection =
+          pipelineWrite.apply(ts.advanceWatermarkToInfinity());
+      // Write into this test's own table; a plain write() batch must not return rows.
+      dataCollection.apply(
+          JdbcIO.<KV<Integer, String>>write()
+              .withDataSourceProviderFn(voidInput -> dataSource)
+              .withStatement(String.format("insert into %s values(?, ?)", firstTableName))
+              .withAutoSharding()
+              .withPreparedStatementSetter(
+                  (element, statement) -> {
+                    statement.setInt(1, element.getKey());
+                    statement.setString(2, element.getValue());
+                  }));
+
+      pipelineWrite.run().waitUntilFinish();
+
+      // NOTE(review): runRead() presumably verifies the suite's shared table rather than
+      // firstTableName; consider asserting the row count of firstTableName directly.
+      runRead();
+    } finally {
+      DatabaseTestHelper.deleteTable(dataSource, firstTableName);
+    }
+  }
+
+  @Test
   public void testWriteWithWriteResults() throws Exception {
     String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
     DatabaseTestHelper.createTable(dataSource, firstTableName);
diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
index 67cd1db..536026a 100644
--- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
+++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
@@ -75,6 +75,7 @@ import org.apache.beam.sdk.schemas.transforms.Select;
 import org.apache.beam.sdk.testing.ExpectedLogs;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestStream;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.SerializableFunction;
@@ -89,6 +90,7 @@ import org.hamcrest.Description;
 import org.hamcrest.TypeSafeMatcher;
 import org.joda.time.DateTime;
 import org.joda.time.Duration;
+import org.joda.time.Instant;
 import org.joda.time.LocalDate;
 import org.joda.time.chrono.ISOChronology;
 import org.junit.BeforeClass;
@@ -529,6 +531,31 @@ public class JdbcIOTest implements Serializable {
   }
 
   @Test
+  public void testWriteWithAutosharding() throws Exception {
+    String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
+    DatabaseTestHelper.createTable(DATA_SOURCE, tableName);
+    TestStream.Builder<KV<Integer, String>> ts =
+        TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()))
+            .advanceWatermarkTo(Instant.now());
+    // Auto-sharding only applies to unbounded input; a TestStream is used here to provide it.
+    try {
+      List<KV<Integer, String>> data = getDataToWrite(EXPECTED_ROW_COUNT);
+      for (KV<Integer, String> elm : data) {
+        // addElements returns a new builder, so it must be reassigned.
+        ts = ts.addElements(elm);
+      }
+      pipeline
+          .apply(ts.advanceWatermarkToInfinity())
+          .apply(getJdbcWrite(tableName).withAutoSharding());
+
+      pipeline.run().waitUntilFinish();
+
+      assertRowCount(DATA_SOURCE, tableName, EXPECTED_ROW_COUNT);
+    } finally {
+      DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName);
+    }
+  }
+
+  @Test
   public void testWriteWithWriteResults() throws Exception {
     String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
     DatabaseTestHelper.createTable(DATA_SOURCE, firstTableName);
@@ -548,6 +575,9 @@ public class JdbcIOTest implements Serializable {
                       }));
       resultSetCollection.setCoder(JdbcTestHelper.TEST_DTO_CODER);
 
+      PAssert.thatSingleton(resultSetCollection.apply(Count.globally()))
+          .isEqualTo((long) EXPECTED_ROW_COUNT);
+
       List<JdbcTestHelper.TestDto> expectedResult = new ArrayList<>();
       for (int i = 0; i < EXPECTED_ROW_COUNT; i++) {
         expectedResult.add(new 
JdbcTestHelper.TestDto(JdbcTestHelper.TestDto.EMPTY_RESULT));

Reply via email to