This is an automated email from the ASF dual-hosted git repository.
pabloem pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new e18e2b9 Merge pull request #16579 from Revert "Revert "Merge pull
request #15863 from [BEAM-13184] Autoshard…
e18e2b9 is described below
commit e18e2b9ba26bee41e93c3109318672eb7b1b26bc
Author: Pablo <[email protected]>
AuthorDate: Thu Jan 27 19:41:53 2022 -0800
Merge pull request #16579 from Revert "Revert "Merge pull request #15863
from [BEAM-13184] Autoshard…
* Revert "Revert "Merge pull request #15863 from [BEAM-13184] Autosharding
for JdbcIO.write* transforms""
This reverts commit 421bc8068fc561a358cfbf6c9842408672872120.
* Using batchSize to define element batch size
* Handle corner case for null list
---
.../java/org/apache/beam/sdk/io/jdbc/JdbcIO.java | 112 +++++++++++++++++----
.../java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java | 39 +++++++
.../org/apache/beam/sdk/io/jdbc/JdbcIOTest.java | 30 ++++++
3 files changed, 164 insertions(+), 17 deletions(-)
diff --git
a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
index 0f6b9c3..14f5e69 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
@@ -66,15 +66,19 @@ import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.GroupIntoBatches;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SerializableFunctions;
+import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.Wait;
+import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.HasDisplayData;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.BackOff;
import org.apache.beam.sdk.util.BackOffUtils;
import org.apache.beam.sdk.util.FluentBackoff;
@@ -82,6 +86,7 @@ import org.apache.beam.sdk.util.Sleeper;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.Row;
@@ -96,6 +101,7 @@ import org.apache.commons.pool2.impl.GenericObjectPool;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Duration;
+import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -1318,6 +1324,11 @@ public class JdbcIO {
this.inner = inner;
}
+    /**
+     * Returns a copy of this transform whose underlying {@link WriteVoid} has auto-sharding enabled.
+     * See {@link WriteVoid#withAutoSharding()}.
+     */
+    public Write<T> withAutoSharding() {
+      WriteVoid<T> autoShardedInner = inner.withAutoSharding();
+      return new Write<>(autoShardedInner);
+    }
+
/** See {@link
WriteVoid#withDataSourceConfiguration(DataSourceConfiguration)}. */
public Write<T> withDataSourceConfiguration(DataSourceConfiguration
config) {
return new Write<>(inner.withDataSourceConfiguration(config));
@@ -1393,6 +1404,7 @@ public class JdbcIO {
.setPreparedStatementSetter(inner.getPreparedStatementSetter())
.setStatement(inner.getStatement())
.setTable(inner.getTable())
+ .setAutoSharding(inner.getAutoSharding())
.build();
}
@@ -1408,6 +1420,51 @@ public class JdbcIO {
}
}
+  /**
+   * Groups the elements of {@code input} into batches of at most {@code batchSize} elements.
+   *
+   * <p>When {@code input} is unbounded and {@code withAutoSharding} is {@code true}, batching is
+   * delegated to {@link GroupIntoBatches} with sharded keys, and a partially filled batch is
+   * flushed after at most 200 ms of buffering. In every other case, elements are accumulated per
+   * bundle. A batch is emitted as soon as it holds {@code batchSize} elements, and any remainder is
+   * emitted when the bundle finishes.
+   *
+   * @param input the elements to batch
+   * @param withAutoSharding whether auto-sharding was requested; {@code null} means not set
+   * @param batchSize the maximum number of elements per emitted batch
+   * @return a collection whose elements are batches of the input elements
+   */
+  static <T> PCollection<Iterable<T>> batchElements(
+      PCollection<T> input, @Nullable Boolean withAutoSharding, long batchSize) {
+    PCollection<Iterable<T>> iterables;
+    if (input.isBounded() == IsBounded.UNBOUNDED && Boolean.TRUE.equals(withAutoSharding)) {
+      iterables =
+          input
+              .apply(WithKeys.<String, T>of(""))
+              .apply(
+                  GroupIntoBatches.<String, T>ofSize(batchSize)
+                      .withMaxBufferingDuration(Duration.millis(200))
+                      .withShardedKey())
+              .apply(Values.create());
+    } else {
+      iterables =
+          input.apply(
+              ParDo.of(
+                  new DoFn<T, Iterable<T>>() {
+                    // Elements buffered in the current bundle; null when nothing is buffered.
+                    private @Nullable List<T> outputList;
+
+                    @ProcessElement
+                    public void process(ProcessContext c) {
+                      // Work on a local so the null check stays valid across the calls below.
+                      List<T> batch = outputList;
+                      if (batch == null) {
+                        batch = new ArrayList<>();
+                        outputList = batch;
+                      }
+                      batch.add(c.element());
+                      // Flush exactly when the batch is full. A '>' check here would emit
+                      // batches of batchSize + 1 elements.
+                      if (batch.size() >= batchSize) {
+                        c.output(batch);
+                        outputList = null;
+                      }
+                    }
+
+                    @FinishBundle
+                    public void finish(FinishBundleContext c) {
+                      List<T> remaining = outputList;
+                      outputList = null;
+                      if (remaining != null && !remaining.isEmpty()) {
+                        // NOTE(review): the remainder is emitted into the GlobalWindow with a
+                        // processing-time timestamp, regardless of the windows the buffered
+                        // elements came from. Confirm this is acceptable for windowed inputs.
+                        c.output(remaining, Instant.now(), GlobalWindow.INSTANCE);
+                      }
+                    }
+                  }));
+    }
+    return iterables;
+  }
+
/** Interface implemented by functions that sets prepared statement data. */
@FunctionalInterface
interface PreparedStatementSetCaller extends Serializable {
@@ -1430,6 +1487,8 @@ public class JdbcIO {
@AutoValue
public abstract static class WriteWithResults<T, V extends JdbcWriteResult>
extends PTransform<PCollection<T>, PCollection<V>> {
+ abstract @Nullable Boolean getAutoSharding();
+
abstract @Nullable SerializableFunction<Void, DataSource>
getDataSourceProviderFn();
abstract @Nullable ValueProvider<String> getStatement();
@@ -1451,6 +1510,8 @@ public class JdbcIO {
abstract Builder<T, V> setDataSourceProviderFn(
SerializableFunction<Void, DataSource> dataSourceProviderFn);
+ abstract Builder<T, V> setAutoSharding(Boolean autoSharding);
+
abstract Builder<T, V> setStatement(ValueProvider<String> statement);
abstract Builder<T, V>
setPreparedStatementSetter(PreparedStatementSetter<T> setter);
@@ -1487,6 +1548,11 @@ public class JdbcIO {
return toBuilder().setPreparedStatementSetter(setter).build();
}
+ /** If true, enables using a dynamically determined number of shards to
write. */
+ public WriteWithResults<T, V> withAutoSharding() {
+ return toBuilder().setAutoSharding(true).build();
+ }
+
/**
* When a SQL exception occurs, {@link Write} uses this {@link
RetryStrategy} to determine if it
* will retry the statements. If {@link RetryStrategy#apply(SQLException)}
returns {@code true},
@@ -1549,8 +1615,15 @@ public class JdbcIO {
checkArgument(
(getDataSourceProviderFn() != null),
"withDataSourceConfiguration() or withDataSourceProviderFn() is
required");
-
- return input.apply(
+ checkArgument(
+ getAutoSharding() == null
+ || (getAutoSharding() && input.isBounded() !=
IsBounded.UNBOUNDED),
+ "Autosharding is only supported for streaming pipelines.");
+ ;
+
+ PCollection<Iterable<T>> iterables =
+ JdbcIO.<T>batchElements(input, getAutoSharding(),
DEFAULT_BATCH_SIZE);
+ return iterables.apply(
ParDo.of(
new WriteFn<T, V>(
WriteFnSpec.builder()
@@ -1573,6 +1646,8 @@ public class JdbcIO {
@AutoValue
public abstract static class WriteVoid<T> extends PTransform<PCollection<T>,
PCollection<Void>> {
+ abstract @Nullable Boolean getAutoSharding();
+
abstract @Nullable SerializableFunction<Void, DataSource>
getDataSourceProviderFn();
abstract @Nullable ValueProvider<String> getStatement();
@@ -1591,6 +1666,8 @@ public class JdbcIO {
@AutoValue.Builder
abstract static class Builder<T> {
+ abstract Builder<T> setAutoSharding(Boolean autoSharding);
+
abstract Builder<T> setDataSourceProviderFn(
SerializableFunction<Void, DataSource> dataSourceProviderFn);
@@ -1609,6 +1686,11 @@ public class JdbcIO {
abstract WriteVoid<T> build();
}
+    /**
+     * Returns a copy of this transform with auto-sharding enabled, so that a dynamically determined
+     * number of shards is used to write.
+     */
+    public WriteVoid<T> withAutoSharding() {
+      Builder<T> builder = toBuilder();
+      builder.setAutoSharding(true);
+      return builder.build();
+    }
+
public WriteVoid<T> withDataSourceConfiguration(DataSourceConfiguration
config) {
return withDataSourceProviderFn(new
DataSourceProviderFromDataSourceConfiguration(config));
}
@@ -1708,7 +1790,11 @@ public class JdbcIO {
checkArgument(
spec.getPreparedStatementSetter() != null,
"withPreparedStatementSetter() is required");
}
- return input
+
+ PCollection<Iterable<T>> iterables =
+ JdbcIO.<T>batchElements(input, getAutoSharding(), getBatchSize());
+
+ return iterables
.apply(
ParDo.of(
new WriteFn<T, Void>(
@@ -1955,7 +2041,7 @@ public class JdbcIO {
* @param <T>
* @param <V>
*/
- static class WriteFn<T, V> extends DoFn<T, V> {
+ static class WriteFn<T, V> extends DoFn<Iterable<T>, V> {
@AutoValue
abstract static class WriteFnSpec<T, V> implements Serializable,
HasDisplayData {
@@ -2045,7 +2131,6 @@ public class JdbcIO {
private Connection connection;
private PreparedStatement preparedStatement;
private static FluentBackoff retryBackOff;
- private final List<T> records = new ArrayList<>();
public WriteFn(WriteFnSpec<T, V> spec) {
this.spec = spec;
@@ -2085,17 +2170,12 @@ public class JdbcIO {
@ProcessElement
public void processElement(ProcessContext context) throws Exception {
- T record = context.element();
- records.add(record);
- if (records.size() >= spec.getBatchSize()) {
- executeBatch(context);
- }
+ executeBatch(context, context.element());
}
@FinishBundle
public void finishBundle() throws Exception {
// We pass a null context because we only execute a final batch for
WriteVoid cases.
- executeBatch(null);
cleanUpStatementAndConnection();
}
@@ -2124,11 +2204,8 @@ public class JdbcIO {
}
}
- private void executeBatch(ProcessContext context)
+ private void executeBatch(ProcessContext context, Iterable<T> records)
throws SQLException, IOException, InterruptedException {
- if (records.isEmpty()) {
- return;
- }
Long startTimeNs = System.nanoTime();
Sleeper sleeper = Sleeper.DEFAULT;
BackOff backoff = retryBackOff.backoff();
@@ -2137,8 +2214,10 @@ public class JdbcIO {
getConnection().prepareStatement(spec.getStatement().get())) {
try {
// add each record in the statement batch
+ int recordsInBatch = 0;
for (T record : records) {
processRecord(record, preparedStatement, context);
+ recordsInBatch += 1;
}
if (!spec.getReturnResults()) {
// execute the batch
@@ -2146,7 +2225,7 @@ public class JdbcIO {
// commit the changes
getConnection().commit();
}
- RECORDS_PER_BATCH.update(records.size());
+ RECORDS_PER_BATCH.update(recordsInBatch);
MS_PER_BATCH.update(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() -
startTimeNs));
break;
} catch (SQLException exception) {
@@ -2164,7 +2243,6 @@ public class JdbcIO {
}
}
}
- records.clear();
}
private void processRecord(T record, PreparedStatement preparedStatement,
ProcessContext c) {
diff --git
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
index 8cebbbd..59bc764 100644
--- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
+++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOIT.java
@@ -32,6 +32,9 @@ import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.io.common.DatabaseTestHelper;
import org.apache.beam.sdk.io.common.HashingFn;
@@ -39,6 +42,7 @@ import
org.apache.beam.sdk.io.common.PostgresIOTestPipelineOptions;
import org.apache.beam.sdk.io.common.TestRow;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestStream;
import org.apache.beam.sdk.testutils.NamedTestResult;
import org.apache.beam.sdk.testutils.metrics.IOITMetrics;
import org.apache.beam.sdk.testutils.metrics.MetricsReader;
@@ -51,6 +55,7 @@ import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Top;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.joda.time.Instant;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Rule;
@@ -255,6 +260,40 @@ public class JdbcIOIT {
}
@Test
+  /**
+   * Streams {@code EXPECTED_ROW_COUNT} rows through an auto-sharded {@link JdbcIO#write()} into a
+   * fresh table, then checks that every row landed there.
+   */
+  public void testWriteWithAutosharding() throws Exception {
+    String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
+    DatabaseTestHelper.createTable(dataSource, firstTableName);
+    try {
+      List<KV<Integer, String>> data = getTestDataToWrite(EXPECTED_ROW_COUNT);
+      TestStream.Builder<KV<Integer, String>> ts =
+          TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()))
+              .advanceWatermarkTo(Instant.now());
+      for (KV<Integer, String> elm : data) {
+        // TestStream.Builder is immutable: addElements returns a new builder. If the result
+        // is not kept, the element is silently dropped from the stream.
+        ts = ts.addElements(elm);
+      }
+
+      PCollection<KV<Integer, String>> dataCollection =
+          pipelineWrite.apply(ts.advanceWatermarkToInfinity());
+      dataCollection.apply(
+          JdbcIO.<KV<Integer, String>>write()
+              .withDataSourceProviderFn(voidInput -> dataSource)
+              // Target the table created above. Plain write() runs the statement as a JDBC
+              // batch, and a batched command that returns a result set (e.g. "returning *")
+              // fails with BatchUpdateException.
+              .withStatement(String.format("insert into %s values(?, ?)", firstTableName))
+              .withAutoSharding()
+              .withPreparedStatementSetter(
+                  (element, statement) -> {
+                    statement.setInt(1, element.getKey());
+                    statement.setString(2, element.getValue());
+                  }));
+
+      pipelineWrite.run().waitUntilFinish();
+
+      // Verify the table that was actually written to (runRead() reads a different table).
+      DatabaseTestHelper.assertRowCount(dataSource, firstTableName, EXPECTED_ROW_COUNT);
+    } finally {
+      DatabaseTestHelper.deleteTable(dataSource, firstTableName);
+    }
+  }
+
+ @Test
public void testWriteWithWriteResults() throws Exception {
String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
DatabaseTestHelper.createTable(dataSource, firstTableName);
diff --git
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
index 67cd1db..536026a 100644
---
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
+++
b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
@@ -75,6 +75,7 @@ import org.apache.beam.sdk.schemas.transforms.Select;
import org.apache.beam.sdk.testing.ExpectedLogs;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.TestStream;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.SerializableFunction;
@@ -89,6 +90,7 @@ import org.hamcrest.Description;
import org.hamcrest.TypeSafeMatcher;
import org.joda.time.DateTime;
import org.joda.time.Duration;
+import org.joda.time.Instant;
import org.joda.time.LocalDate;
import org.joda.time.chrono.ISOChronology;
import org.junit.BeforeClass;
@@ -529,6 +531,31 @@ public class JdbcIOTest implements Serializable {
}
@Test
+  /**
+   * Feeds {@code EXPECTED_ROW_COUNT} elements through a {@link TestStream} into an auto-sharded
+   * JDBC write and checks that the target table ends up with that many rows.
+   */
+  public void testWriteWithAutosharding() throws Exception {
+    String table = DatabaseTestHelper.getTestTableName("UT_WRITE");
+    DatabaseTestHelper.createTable(DATA_SOURCE, table);
+    TestStream.Builder<KV<Integer, String>> stream =
+        TestStream.create(KvCoder.of(VarIntCoder.of(), StringUtf8Coder.of()))
+            .advanceWatermarkTo(Instant.now());
+
+    try {
+      List<KV<Integer, String>> rows = getDataToWrite(EXPECTED_ROW_COUNT);
+      // The builder is immutable, so keep the builder returned by each addElements call.
+      for (int i = 0; i < rows.size(); i++) {
+        stream = stream.addElements(rows.get(i));
+      }
+
+      pipeline
+          .apply(stream.advanceWatermarkToInfinity())
+          .apply(getJdbcWrite(table).withAutoSharding());
+      pipeline.run().waitUntilFinish();
+
+      assertRowCount(DATA_SOURCE, table, EXPECTED_ROW_COUNT);
+    } finally {
+      DatabaseTestHelper.deleteTable(DATA_SOURCE, table);
+    }
+  }
+
+ @Test
public void testWriteWithWriteResults() throws Exception {
String firstTableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
DatabaseTestHelper.createTable(DATA_SOURCE, firstTableName);
@@ -548,6 +575,9 @@ public class JdbcIOTest implements Serializable {
}));
resultSetCollection.setCoder(JdbcTestHelper.TEST_DTO_CODER);
+ PAssert.thatSingleton(resultSetCollection.apply(Count.globally()))
+ .isEqualTo((long) EXPECTED_ROW_COUNT);
+
List<JdbcTestHelper.TestDto> expectedResult = new ArrayList<>();
for (int i = 0; i < EXPECTED_ROW_COUNT; i++) {
expectedResult.add(new
JdbcTestHelper.TestDto(JdbcTestHelper.TestDto.EMPTY_RESULT));