This is an automated email from the ASF dual-hosted git repository.
yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 85817b6b343 Add partitioned reads to JDBC SchemaIO (#25577)
85817b6b343 is described below
commit 85817b6b34380a27ec534a7aac32e3342535d9f6
Author: Yi Hu <[email protected]>
AuthorDate: Wed Feb 22 12:19:17 2023 -0500
Add partitioned reads to JDBC SchemaIO (#25577)
* Fix up JdbcSchemaIO to support partitioned reads on a column (numeric and
datetime columns are believed to be supported). Start adding a
JdbcPartitionedReadSchemaTransformProvider as a more generic SchemaTransform.
This fits the SchemaTransform approach better, because the partitioned read is
an entirely different transform from the non-partitioned version.
* Removed the PartitionedReadSchemaTransformProvider pending further
discussion. Added a Python-side test that should exercise this code path
(this is difficult to verify fully). Confirmed that the test actually runs
and that it fails if something is badly wrong.
* address comments
* Support Int16 type in schema
* Fix pylint; fix unsafe cast; fix test
---------
Co-authored-by: Byron Ellis <[email protected]>
---
.../beam/sdk/io/jdbc/JdbcSchemaIOProvider.java | 63 +++++++---
.../beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java | 129 +++++++++++++++++++++
sdks/python/apache_beam/coders/coder_impl.pxd | 4 +
sdks/python/apache_beam/coders/coder_impl.py | 18 ++-
sdks/python/apache_beam/coders/coders.py | 19 +++
.../apache_beam/coders/coders_test_common.py | 1 +
sdks/python/apache_beam/coders/row_coder.py | 3 +
sdks/python/apache_beam/coders/slow_stream.py | 6 +
sdks/python/apache_beam/coders/stream.pxd | 3 +
sdks/python/apache_beam/coders/stream.pyx | 16 +++
sdks/python/apache_beam/coders/stream_test.py | 9 ++
.../io/external/xlang_jdbcio_it_test.py | 18 +++
sdks/python/apache_beam/io/jdbc.py | 37 +++---
13 files changed, 292 insertions(+), 34 deletions(-)
diff --git
a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java
b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java
index 77ec4082f6f..b5969e31809 100644
---
a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java
+++
b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProvider.java
@@ -67,6 +67,10 @@ public class JdbcSchemaIOProvider implements
SchemaIOProvider {
.addNullableField("fetchSize", FieldType.INT16)
.addNullableField("outputParallelization", FieldType.BOOLEAN)
.addNullableField("autosharding", FieldType.BOOLEAN)
+ // Partitioning support. If you specify a partition column we will use
that instead of
+ // readQuery
+ .addNullableField("partitionColumn", FieldType.STRING)
+ .addNullableField("partitions", FieldType.INT16)
.build();
}
@@ -110,26 +114,49 @@ public class JdbcSchemaIOProvider implements
SchemaIOProvider {
return new PTransform<PBegin, PCollection<Row>>() {
@Override
public PCollection<Row> expand(PBegin input) {
- @Nullable String readQuery = config.getString("readQuery");
- if (readQuery == null) {
- readQuery = String.format("SELECT * FROM %s", location);
- }
-
- JdbcIO.ReadRows readRows =
- JdbcIO.readRows()
- .withDataSourceConfiguration(getDataSourceConfiguration())
- .withQuery(readQuery);
-
- @Nullable Short fetchSize = config.getInt16("fetchSize");
- if (fetchSize != null) {
- readRows = readRows.withFetchSize(fetchSize);
- }
- @Nullable Boolean outputParallelization =
config.getBoolean("outputParallelization");
- if (outputParallelization != null) {
- readRows =
readRows.withOutputParallelization(outputParallelization);
+ // A configured partitionColumn selects JdbcIO.readWithPartitions instead of
+ // JdbcIO.readRows. The hasField() guard tolerates config schemas that predate
+ // the partitioning fields.
+ @Nullable
+ String partitionColumn =
+ config.getSchema().hasField("partitionColumn")
+ ? config.getString("partitionColumn")
+ : null;
+ if (partitionColumn != null) {
+ // Partitioned path: reads the whole table named by `location`, split on
+ // partitionColumn. readQuery, fetchSize and outputParallelization are NOT
+ // applied on this path.
+ JdbcIO.ReadWithPartitions<Row, ?> readRows =
+ JdbcIO.<Row>readWithPartitions()
+ .withDataSourceConfiguration(getDataSourceConfiguration())
+ .withTable(location)
+ .withPartitionColumn(partitionColumn)
+ .withRowOutput();
+ // NOTE(review): unlike partitionColumn above, this lookup is not guarded by
+ // hasField(); it assumes any schema with partitionColumn also declares
+ // "partitions" -- confirm.
+ @Nullable Short partitions = config.getInt16("partitions");
+ if (partitions != null) {
+ readRows = readRows.withNumPartitions(partitions);
+ }
+ return input.apply(readRows);
+ } else {
+
+ // Non-partitioned path: run readQuery, defaulting to a full-table SELECT.
+ @Nullable String readQuery = config.getString("readQuery");
+ if (readQuery == null) {
+ readQuery = String.format("SELECT * FROM %s", location);
+ }
+
+ JdbcIO.ReadRows readRows =
+ JdbcIO.readRows()
+ .withDataSourceConfiguration(getDataSourceConfiguration())
+ .withQuery(readQuery);
+
+ @Nullable Short fetchSize = config.getInt16("fetchSize");
+ if (fetchSize != null) {
+ readRows = readRows.withFetchSize(fetchSize);
+ }
+
+ @Nullable Boolean outputParallelization =
config.getBoolean("outputParallelization");
+ if (outputParallelization != null) {
+ readRows =
readRows.withOutputParallelization(outputParallelization);
+ }
+
+ return input.apply(readRows);
}
- return input.apply(readRows);
}
};
}
diff --git
a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java
b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java
new file mode 100644
index 00000000000..d91eaaef6e6
--- /dev/null
+++
b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcSchemaIOProviderTest.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.jdbc;
+
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.SQLException;
+import javax.sql.DataSource;
+import org.apache.beam.sdk.io.common.DatabaseTestHelper;
+import org.apache.beam.sdk.io.common.TestRow;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for the partitioned-read path of {@link JdbcSchemaIOProvider}. */
+@RunWith(JUnit4.class)
+public class JdbcSchemaIOProviderTest {
+
+  private static final JdbcIO.DataSourceConfiguration DATA_SOURCE_CONFIGURATION =
+      JdbcIO.DataSourceConfiguration.create(
+          "org.apache.derby.jdbc.EmbeddedDriver", "jdbc:derby:memory:testDB;create=true");
+  private static final int EXPECTED_ROW_COUNT = 1000;
+
+  private static final DataSource DATA_SOURCE = DATA_SOURCE_CONFIGURATION.buildDatasource();
+  private static final String READ_TABLE_NAME = DatabaseTestHelper.getTestTableName("UT_READ");
+
+  @Rule public final transient TestPipeline pipeline = TestPipeline.create();
+
+  @BeforeClass
+  public static void beforeClass() throws Exception {
+    // By default, Derby uses a lock timeout of 60 seconds. To speed up the test and detect
+    // lock contention faster, decrease this timeout.
+    System.setProperty("derby.locks.waitTimeout", "2");
+    System.setProperty("derby.stream.error.file", "build/derby.log");
+
+    DatabaseTestHelper.createTable(DATA_SOURCE, READ_TABLE_NAME);
+    addInitialData(DATA_SOURCE, READ_TABLE_NAME);
+  }
+
+  /** Partitioning on the numeric "id" column must read all EXPECTED_ROW_COUNT rows. */
+  @Test
+  public void testPartitionedRead() {
+    PCollection<Row> output = pipeline.apply(partitionedSchemaIO("id").buildReader());
+    PAssert.that(output.apply(Count.globally()))
+        .containsInAnyOrder(Long.valueOf(EXPECTED_ROW_COUNT));
+    pipeline.run();
+  }
+
+  /**
+   * Partitioned reads currently support only numeric and datetime partition columns, so
+   * partitioning on the string column "name" must make the pipeline fail.
+   */
+  @Test
+  public void testPartitionedReadThatShouldntWork() {
+    PCollection<Row> output = pipeline.apply(partitionedSchemaIO("name").buildReader());
+    PAssert.that(output.apply(Count.globally()))
+        .containsInAnyOrder(Long.valueOf(EXPECTED_ROW_COUNT));
+    try {
+      pipeline.run();
+    } catch (Exception expected) {
+      // Expected failure: a string partition column is not supported.
+      return;
+    }
+    throw new AssertionError(
+        "Expected the pipeline to fail when partitioning on string column \"name\"");
+  }
+
+  /**
+   * Builds a JdbcSchemaIO over READ_TABLE_NAME that partitions on {@code partitionColumn} into 10
+   * partitions, using the shared Derby data source.
+   */
+  private static JdbcSchemaIOProvider.JdbcSchemaIO partitionedSchemaIO(String partitionColumn) {
+    JdbcSchemaIOProvider provider = new JdbcSchemaIOProvider();
+    Row config =
+        Row.withSchema(provider.configurationSchema())
+            .withFieldValue(
+                "driverClassName", DATA_SOURCE_CONFIGURATION.getDriverClassName().get())
+            .withFieldValue("jdbcUrl", DATA_SOURCE_CONFIGURATION.getUrl().get())
+            .withFieldValue("username", "")
+            .withFieldValue("password", "")
+            .withFieldValue("partitionColumn", partitionColumn)
+            .withFieldValue("partitions", (short) 10)
+            .build();
+    return provider.from(READ_TABLE_NAME, config, Schema.builder().build());
+  }
+
+  /** Create test data that is consistent with that generated by TestRow. */
+  private static void addInitialData(DataSource dataSource, String tableName)
+      throws SQLException {
+    try (Connection connection = dataSource.getConnection()) {
+      connection.setAutoCommit(false);
+      try (PreparedStatement preparedStatement =
+          connection.prepareStatement(String.format("insert into %s values (?,?)", tableName))) {
+        for (int i = 0; i < EXPECTED_ROW_COUNT; i++) {
+          preparedStatement.clearParameters();
+          preparedStatement.setInt(1, i);
+          preparedStatement.setString(2, TestRow.getNameForSeed(i));
+          preparedStatement.executeUpdate();
+        }
+      }
+      connection.commit();
+    }
+  }
+}
diff --git a/sdks/python/apache_beam/coders/coder_impl.pxd
b/sdks/python/apache_beam/coders/coder_impl.pxd
index 5714f8beeee..0e6e31d0fc8 100644
--- a/sdks/python/apache_beam/coders/coder_impl.pxd
+++ b/sdks/python/apache_beam/coders/coder_impl.pxd
@@ -109,6 +109,10 @@ cdef class BooleanCoderImpl(CoderImpl):
pass
+cdef class BigEndianShortCoderImpl(StreamCoderImpl):
+ pass
+
+
cdef class SinglePrecisionFloatCoderImpl(StreamCoderImpl):
pass
diff --git a/sdks/python/apache_beam/coders/coder_impl.py
b/sdks/python/apache_beam/coders/coder_impl.py
index 094687ce68d..cccc73662ce 100644
--- a/sdks/python/apache_beam/coders/coder_impl.py
+++ b/sdks/python/apache_beam/coders/coder_impl.py
@@ -758,6 +758,22 @@ class NullableCoderImpl(StreamCoderImpl):
if unused_value is not None else 0)
+class BigEndianShortCoderImpl(StreamCoderImpl):
+ """For internal use only; no backwards-compatibility guarantees."""
+ # Encodes a single int16 value as two big-endian bytes via the stream's
+ # write_bigendian_int16 / read_bigendian_int16 helpers.
+ def encode_to_stream(self, value, out, nested):
+ # type: (int, create_OutputStream, bool) -> None
+ out.write_bigendian_int16(value)
+
+ def decode_from_stream(self, in_stream, nested):
+ # type: (create_InputStream, bool) -> int
+ return in_stream.read_bigendian_int16()
+
+ def estimate_size(self, unused_value, nested=False):
+ # type: (Any, bool) -> int
+ # A short is encoded as 2 bytes, regardless of nesting.
+ return 2
+
+
class SinglePrecisionFloatCoderImpl(StreamCoderImpl):
"""For internal use only; no backwards-compatibility guarantees."""
def encode_to_stream(self, value, out, nested):
@@ -770,7 +786,7 @@ class SinglePrecisionFloatCoderImpl(StreamCoderImpl):
def estimate_size(self, unused_value, nested=False):
# type: (Any, bool) -> int
- # A double is encoded as 8 bytes, regardless of nesting.
+ # A float is encoded as 4 bytes, regardless of nesting.
return 4
diff --git a/sdks/python/apache_beam/coders/coders.py
b/sdks/python/apache_beam/coders/coders.py
index 25fabc951c5..d4ca99b80fb 100644
--- a/sdks/python/apache_beam/coders/coders.py
+++ b/sdks/python/apache_beam/coders/coders.py
@@ -682,6 +682,25 @@ class VarIntCoder(FastCoder):
Coder.register_structured_urn(common_urns.coders.VARINT.urn, VarIntCoder)
+class BigEndianShortCoder(FastCoder):
+ """A coder for signed 16-bit integers, encoded as 2 big-endian bytes.
+
+ Used by the row coder for schema fields of atomic type INT16.
+ """
+ def _create_impl(self):
+ return coder_impl.BigEndianShortCoderImpl()
+
+ def is_deterministic(self):
+ # type: () -> bool
+ # Fixed-width encoding: equal ints always produce identical bytes.
+ return True
+
+ def to_type_hint(self):
+ return int
+
+ # The coder holds no state, so any two instances are interchangeable.
+ def __eq__(self, other):
+ return type(self) == type(other)
+
+ def __hash__(self):
+ return hash(type(self))
+
+
class SinglePrecisionFloatCoder(FastCoder):
"""A coder used for single-precision floating-point values."""
def _create_impl(self):
diff --git a/sdks/python/apache_beam/coders/coders_test_common.py
b/sdks/python/apache_beam/coders/coders_test_common.py
index a0bec891bdf..7adb06cb287 100644
--- a/sdks/python/apache_beam/coders/coders_test_common.py
+++ b/sdks/python/apache_beam/coders/coders_test_common.py
@@ -160,6 +160,7 @@ class CodersTest(unittest.TestCase):
coders.ListLikeCoder,
coders.ProtoCoder,
coders.ProtoPlusCoder,
+ coders.BigEndianShortCoder,
coders.SinglePrecisionFloatCoder,
coders.ToBytesCoder,
coders.BigIntegerCoder, # tested in DecimalCoder
diff --git a/sdks/python/apache_beam/coders/row_coder.py
b/sdks/python/apache_beam/coders/row_coder.py
index 9dd4dcd9f63..19424fa1f12 100644
--- a/sdks/python/apache_beam/coders/row_coder.py
+++ b/sdks/python/apache_beam/coders/row_coder.py
@@ -22,6 +22,7 @@ from google.protobuf import json_format
from apache_beam.coders import typecoders
from apache_beam.coders.coder_impl import LogicalTypeCoderImpl
from apache_beam.coders.coder_impl import RowCoderImpl
+from apache_beam.coders.coders import BigEndianShortCoder
from apache_beam.coders.coders import BooleanCoder
from apache_beam.coders.coders import BytesCoder
from apache_beam.coders.coders import Coder
@@ -153,6 +154,8 @@ def _nonnull_coder_from_type(field_type):
if type_info == "atomic_type":
if field_type.atomic_type in (schema_pb2.INT32, schema_pb2.INT64):
return VarIntCoder()
+ if field_type.atomic_type == schema_pb2.INT16:
+ return BigEndianShortCoder()
elif field_type.atomic_type == schema_pb2.FLOAT:
return SinglePrecisionFloatCoder()
elif field_type.atomic_type == schema_pb2.DOUBLE:
diff --git a/sdks/python/apache_beam/coders/slow_stream.py
b/sdks/python/apache_beam/coders/slow_stream.py
index 11ccf7fd2e3..71a5b45d769 100644
--- a/sdks/python/apache_beam/coders/slow_stream.py
+++ b/sdks/python/apache_beam/coders/slow_stream.py
@@ -69,6 +69,9 @@ class OutputStream(object):
def write_bigendian_int32(self, v):
self.write(struct.pack('>i', v))
+ def write_bigendian_int16(self, v):
+ # Signed 16-bit, big-endian ('>h'); struct.pack raises struct.error when v
+ # is outside [-2**15, 2**15 - 1].
+ self.write(struct.pack('>h', v))
+
def write_bigendian_double(self, v):
self.write(struct.pack('>d', v))
@@ -172,6 +175,9 @@ class InputStream(object):
def read_bigendian_int32(self):
return struct.unpack('>i', self.read(4))[0]
+ def read_bigendian_int16(self):
+ # Consumes 2 bytes and decodes them as a signed 16-bit big-endian integer.
+ return struct.unpack('>h', self.read(2))[0]
+
def read_bigendian_double(self):
return struct.unpack('>d', self.read(8))[0]
diff --git a/sdks/python/apache_beam/coders/stream.pxd
b/sdks/python/apache_beam/coders/stream.pxd
index fc179bb8c1b..97d66aa089a 100644
--- a/sdks/python/apache_beam/coders/stream.pxd
+++ b/sdks/python/apache_beam/coders/stream.pxd
@@ -29,6 +29,7 @@ cdef class OutputStream(object):
cpdef write_bigendian_int64(self, libc.stdint.int64_t signed_v)
cpdef write_bigendian_uint64(self, libc.stdint.uint64_t signed_v)
cpdef write_bigendian_int32(self, libc.stdint.int32_t signed_v)
+ cpdef write_bigendian_int16(self, libc.stdint.int16_t signed_v)
cpdef write_bigendian_double(self, double d)
cpdef write_bigendian_float(self, float d)
@@ -46,6 +47,7 @@ cdef class ByteCountingOutputStream(OutputStream):
cpdef write_bigendian_int64(self, libc.stdint.int64_t val)
cpdef write_bigendian_uint64(self, libc.stdint.uint64_t val)
cpdef write_bigendian_int32(self, libc.stdint.int32_t val)
+ cpdef write_bigendian_int16(self, libc.stdint.int16_t val)
cpdef size_t get_count(self)
cpdef bytes get(self)
@@ -62,6 +64,7 @@ cdef class InputStream(object):
cpdef libc.stdint.int64_t read_bigendian_int64(self) except? -1
cpdef libc.stdint.uint64_t read_bigendian_uint64(self) except? -1
cpdef libc.stdint.int32_t read_bigendian_int32(self) except? -1
+ cpdef libc.stdint.int16_t read_bigendian_int16(self) except? -1
cpdef double read_bigendian_double(self) except? -1
cpdef float read_bigendian_float(self) except? -1
cpdef bytes read_all(self, bint nested=*)
diff --git a/sdks/python/apache_beam/coders/stream.pyx
b/sdks/python/apache_beam/coders/stream.pyx
index 14536b007cc..8f941c151bd 100644
--- a/sdks/python/apache_beam/coders/stream.pyx
+++ b/sdks/python/apache_beam/coders/stream.pyx
@@ -101,6 +101,14 @@ cdef class OutputStream(object):
self.data[self.pos + 3] = <unsigned char>(v )
self.pos += 4
+ cpdef write_bigendian_int16(self, libc.stdint.int16_t signed_v):
+ # Reinterpret as unsigned so the shift below works on the two's-complement
+ # bit pattern of signed_v.
+ cdef libc.stdint.uint16_t v = signed_v
+ # Fewer than 2 bytes of capacity left: extend(2) presumably grows self.data.
+ if self.buffer_size < self.pos + 2:
+ self.extend(2)
+ # High byte first (big-endian).
+ self.data[self.pos ] = <unsigned char>(v >> 8)
+ self.data[self.pos + 1] = <unsigned char>(v )
+ self.pos += 2
+
cpdef write_bigendian_double(self, double d):
self.write_bigendian_int64((<libc.stdint.int64_t*><char*>&d)[0])
@@ -157,6 +165,9 @@ cdef class ByteCountingOutputStream(OutputStream):
cpdef write_bigendian_int32(self, libc.stdint.int32_t _):
self.count += 4
+ cpdef write_bigendian_int16(self, libc.stdint.int16_t _):
+ # Only tallies the 2 bytes an int16 occupies; the value itself is ignored.
+ self.count += 2
+
cpdef size_t get_count(self):
return self.count
@@ -237,6 +248,11 @@ cdef class InputStream(object):
| <libc.stdint.uint32_t><unsigned char>self.allc[self.pos - 3] << 16
| <libc.stdint.uint32_t><unsigned char>self.allc[self.pos - 4] << 24)
+ cpdef libc.stdint.int16_t read_bigendian_int16(self) except? -1:
+ # Advance first, then index backwards: allc[pos - 2] is the high byte and
+ # allc[pos - 1] the low byte.
+ self.pos += 2
+ # NOTE(review): no bounds check is visible here. Values >= 2**15 become
+ # negative only through the implicit narrowing to the int16_t return type
+ # (implementation-defined in C, two's-complement wrap on common compilers);
+ # confirm this matches the int32 reader's assumptions.
+ return (<unsigned char>self.allc[self.pos - 1]
+ | <libc.stdint.uint16_t><unsigned char>self.allc[self.pos - 2] << 8)
+
cpdef double read_bigendian_double(self) except? -1:
cdef libc.stdint.int64_t as_long = self.read_bigendian_int64()
return (<double*><char*>&as_long)[0]
diff --git a/sdks/python/apache_beam/coders/stream_test.py
b/sdks/python/apache_beam/coders/stream_test.py
index 35b64eb9581..57662056b2a 100644
--- a/sdks/python/apache_beam/coders/stream_test.py
+++ b/sdks/python/apache_beam/coders/stream_test.py
@@ -139,6 +139,15 @@ class StreamTest(unittest.TestCase):
for v in values:
self.assertEqual(v, in_s.read_bigendian_int32())
+ def test_read_write_bigendian_int16(self):
+ # Round-trips zero, +/-1, both int16 extremes and an arbitrary in-range
+ # value (int(2**13 * pi) == 25735).
+ values = 0, 1, -1, 2**15 - 1, -2**15, int(2**13 * math.pi)
+ out_s = self.OutputStream()
+ for v in values:
+ out_s.write_bigendian_int16(v)
+ in_s = self.InputStream(out_s.get())
+ for v in values:
+ self.assertEqual(v, in_s.read_bigendian_int16())
+
def test_byte_counting(self):
bc_s = self.ByteCountingOutputStream()
self.assertEqual(0, bc_s.get_count())
diff --git a/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py
b/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py
index 1dcb56c51ec..ed8745ec2ac 100644
--- a/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py
+++ b/sdks/python/apache_beam/io/external/xlang_jdbcio_it_test.py
@@ -201,6 +201,24 @@ class CrossLanguageJdbcIOTest(unittest.TestCase):
assert_that(result, equal_to(expected_row))
+ # Try the same read using the partitioned reader code path.
+ # Outputs should be the same.
+ with TestPipeline() as p:
+ p.not_use_test_runner_api = True
+ result = (
+ p
+ | 'Partitioned read from jdbc' >> ReadFromJdbc(
+ table_name=table_name,
+ partition_column='f_id',
+ partitions=3,
+ driver_class_name=self.driver_class_name,
+ jdbc_url=self.jdbc_url,
+ username=self.username,
+ password=self.password,
+ classpath=classpath))
+
+ assert_that(result, equal_to(expected_row))
+
# Creating a container with testcontainers sometimes raises ReadTimeout
# error. In java there are 2 retries set by default.
def start_db_container(self, retries, container_init):
diff --git a/sdks/python/apache_beam/io/jdbc.py
b/sdks/python/apache_beam/io/jdbc.py
index 85b80fdea0e..aa539871601 100644
--- a/sdks/python/apache_beam/io/jdbc.py
+++ b/sdks/python/apache_beam/io/jdbc.py
@@ -88,6 +88,8 @@
import typing
+import numpy as np
+
from apache_beam.coders import RowCoder
from apache_beam.transforms.external import BeamJarExpansionService
from apache_beam.transforms.external import ExternalTransform
@@ -113,19 +115,16 @@ JdbcConfigSchema = typing.NamedTuple(
Config = typing.NamedTuple(
'Config',
- [
- ('driver_class_name', str),
- ('jdbc_url', str),
- ('username', str),
- ('password', str),
- ('connection_properties', typing.Optional[str]),
- ('connection_init_sqls', typing.Optional[typing.List[str]]),
- ('read_query', typing.Optional[str]),
- ('write_statement', typing.Optional[str]),
- ('fetch_size', typing.Optional[int]),
- ('output_parallelization', typing.Optional[bool]),
- ('autosharding', typing.Optional[bool]),
- ],
+ [('driver_class_name', str), ('jdbc_url', str), ('username', str),
+ ('password', str), ('connection_properties', typing.Optional[str]),
+ ('connection_init_sqls', typing.Optional[typing.List[str]]),
+ ('read_query', typing.Optional[str]),
+ ('write_statement', typing.Optional[str]),
+ ('fetch_size', typing.Optional[int]),
+ ('output_parallelization', typing.Optional[bool]),
+ ('autosharding', typing.Optional[bool]),
+ # Setting partition_column switches the Java side to a partitioned read.
+ # partitions is declared np.int16 rather than int, presumably so the row
+ # schema uses INT16 to match the Java provider's "partitions" field; confirm.
+ ('partition_column', typing.Optional[str]),
+ ('partitions', typing.Optional[np.int16])],
)
DEFAULT_JDBC_CLASSPATH = ['org.postgresql:postgresql:42.2.16']
@@ -226,7 +225,8 @@ class WriteToJdbc(ExternalTransform):
fetch_size=None,
output_parallelization=None,
autosharding=autosharding,
- ))),
+ partitions=None,
+ partition_column=None))),
),
expansion_service or default_io_expansion_service(classpath),
)
@@ -273,6 +273,8 @@ class ReadFromJdbc(ExternalTransform):
query=None,
output_parallelization=None,
fetch_size=None,
+ partition_column=None,
+ partitions=None,
connection_properties=None,
connection_init_sqls=None,
expansion_service=None,
@@ -288,6 +290,10 @@ class ReadFromJdbc(ExternalTransform):
:param query: sql query to be executed
:param output_parallelization: is output parallelization on
:param fetch_size: how many rows to fetch
+ :param partition_column: enable partitioned reads by splitting on this
+ column
+ :param partitions: override the default number of splits when using
+ partition_column
:param connection_properties: properties of the jdbc connection
passed as string with format
[propertyName=property;]*
@@ -324,7 +330,8 @@ class ReadFromJdbc(ExternalTransform):
fetch_size=fetch_size,
output_parallelization=output_parallelization,
autosharding=None,
- ))),
+ partition_column=partition_column,
+ partitions=partitions))),
),
expansion_service or default_io_expansion_service(classpath),
)