This is an automated email from the ASF dual-hosted git repository.
yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 4cc0bccdf56 Fix excessive checkStateNotNull in JdbcUtil (#25847)
4cc0bccdf56 is described below
commit 4cc0bccdf56fca6a219bbd210fcbbc32666b738d
Author: Yi Hu <[email protected]>
AuthorDate: Wed Mar 15 11:29:52 2023 -0400
Fix excessive checkStateNotNull in JdbcUtil (#25847)
* Also fix writing null values to Derby (unit test)
---
.../java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java | 19 ++--
.../org/apache/beam/sdk/io/jdbc/SchemaUtil.java | 6 +-
.../org/apache/beam/sdk/io/jdbc/JdbcIOTest.java | 110 +++++++++++++++++----
3 files changed, 107 insertions(+), 28 deletions(-)
diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
index 04a24dd0daf..41be4074ef6 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
@@ -85,7 +85,7 @@ class JdbcUtil {
(element, ps, i, fieldWithIndex) -> {
Byte value = element.getByte(fieldWithIndex.getIndex());
if (value == null) {
- setNullToPreparedStatement(ps, i);
+ setNullToPreparedStatement(ps, i, JDBCType.TINYINT);
} else {
ps.setByte(i + 1, value);
}
@@ -95,7 +95,7 @@ class JdbcUtil {
(element, ps, i, fieldWithIndex) -> {
Short value = element.getInt16(fieldWithIndex.getIndex());
if (value == null) {
- setNullToPreparedStatement(ps, i);
+ setNullToPreparedStatement(ps, i, JDBCType.SMALLINT);
} else {
ps.setInt(i + 1, value);
}
@@ -105,7 +105,7 @@ class JdbcUtil {
(element, ps, i, fieldWithIndex) -> {
Long value = element.getInt64(fieldWithIndex.getIndex());
if (value == null) {
- setNullToPreparedStatement(ps, i);
+ setNullToPreparedStatement(ps, i, JDBCType.BIGINT);
} else {
ps.setLong(i + 1, value);
}
@@ -119,7 +119,7 @@ class JdbcUtil {
(element, ps, i, fieldWithIndex) -> {
Float value = element.getFloat(fieldWithIndex.getIndex());
if (value == null) {
- setNullToPreparedStatement(ps, i);
+ setNullToPreparedStatement(ps, i, JDBCType.FLOAT);
} else {
ps.setFloat(i + 1, value);
}
@@ -129,7 +129,7 @@ class JdbcUtil {
(element, ps, i, fieldWithIndex) -> {
Double value =
element.getDouble(fieldWithIndex.getIndex());
if (value == null) {
- setNullToPreparedStatement(ps, i);
+ setNullToPreparedStatement(ps, i, JDBCType.DOUBLE);
} else {
ps.setDouble(i + 1, value);
}
@@ -145,7 +145,7 @@ class JdbcUtil {
(element, ps, i, fieldWithIndex) -> {
Boolean value =
element.getBoolean(fieldWithIndex.getIndex());
if (value == null) {
- setNullToPreparedStatement(ps, i);
+ setNullToPreparedStatement(ps, i, JDBCType.BOOLEAN);
} else {
ps.setBoolean(i + 1, value);
}
@@ -156,7 +156,7 @@ class JdbcUtil {
(element, ps, i, fieldWithIndex) -> {
Integer value =
element.getInt32(fieldWithIndex.getIndex());
if (value == null) {
- setNullToPreparedStatement(ps, i);
+ setNullToPreparedStatement(ps, i, JDBCType.INTEGER);
} else {
ps.setInt(i + 1, value);
}
@@ -267,8 +267,9 @@ class JdbcUtil {
ps.setArray(i + 1, null);
}
- static void setNullToPreparedStatement(PreparedStatement ps, int i) throws
SQLException {
- ps.setNull(i + 1, JDBCType.NULL.getVendorTypeNumber());
+ static void setNullToPreparedStatement(PreparedStatement ps, int i, JDBCType
type)
+ throws SQLException {
+ ps.setNull(i + 1, type.getVendorTypeNumber());
}
static class BeamRowPreparedStatementSetter implements
JdbcIO.PreparedStatementSetter<Row> {
diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java
index 05414fee0be..458b0c2c82a 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java
@@ -27,7 +27,6 @@ import static java.sql.JDBCType.NVARCHAR;
import static java.sql.JDBCType.VARBINARY;
import static java.sql.JDBCType.VARCHAR;
import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
-import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
import static
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import java.io.Serializable;
@@ -304,7 +303,10 @@ class SchemaUtil {
} else {
ResultSetFieldExtractor extractor =
createFieldExtractor(fieldType.getBaseType());
return (rs, index) -> {
- BaseT v = checkStateNotNull((BaseT) extractor.extract(rs, index));
+ BaseT v = (BaseT) extractor.extract(rs, index);
+ if (v == null) {
+ return null;
+ }
return fieldType.toInputType(v);
};
}
diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
index 72fb670b2ac..0e1874a08b7 100644
--- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
+++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
@@ -17,7 +17,6 @@
*/
package org.apache.beam.sdk.io.jdbc;
-import static java.sql.JDBCType.NULL;
import static org.apache.beam.sdk.io.common.DatabaseTestHelper.assertRowCount;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.closeTo;
@@ -119,6 +118,8 @@ public class JdbcIOTest implements Serializable {
@Rule public final transient TestPipeline pipeline = TestPipeline.create();
+ @Rule public final transient TestPipeline secondPipeline =
TestPipeline.create();
+
@Rule public final transient ExpectedLogs expectedLogs =
ExpectedLogs.none(JdbcIO.class);
@Rule public transient ExpectedException thrown = ExpectedException.none();
@@ -991,13 +992,13 @@ public class JdbcIOTest implements Serializable {
.set(row, psMocked, 10,
SchemaUtil.FieldWithIndex.of(schema.getField(8), 8));
// primitive
- verify(psMocked, times(1)).setNull(1, NULL.getVendorTypeNumber());
- verify(psMocked, times(1)).setNull(2, NULL.getVendorTypeNumber());
- verify(psMocked, times(1)).setNull(3, NULL.getVendorTypeNumber());
- verify(psMocked, times(1)).setNull(4, NULL.getVendorTypeNumber());
- verify(psMocked, times(1)).setNull(5, NULL.getVendorTypeNumber());
- verify(psMocked, times(1)).setNull(6, NULL.getVendorTypeNumber());
- verify(psMocked, times(1)).setNull(7, NULL.getVendorTypeNumber());
+ verify(psMocked, times(1)).setNull(1,
JDBCType.BIGINT.getVendorTypeNumber());
+ verify(psMocked, times(1)).setNull(2,
JDBCType.BOOLEAN.getVendorTypeNumber());
+ verify(psMocked, times(1)).setNull(3,
JDBCType.DOUBLE.getVendorTypeNumber());
+ verify(psMocked, times(1)).setNull(4,
JDBCType.FLOAT.getVendorTypeNumber());
+ verify(psMocked, times(1)).setNull(5,
JDBCType.INTEGER.getVendorTypeNumber());
+ verify(psMocked, times(1)).setNull(6,
JDBCType.SMALLINT.getVendorTypeNumber());
+ verify(psMocked, times(1)).setNull(7,
JDBCType.TINYINT.getVendorTypeNumber());
// reference
verify(psMocked, times(1)).setBytes(8, null);
verify(psMocked, times(1)).setString(9, null);
@@ -1095,20 +1096,32 @@ public class JdbcIOTest implements Serializable {
verify(psMocked, times(1)).setArray(1, arrayMocked);
}
- private static ArrayList<Row> getRowsToWrite(long rowsToAdd, Schema schema) {
+ private static ArrayList<Row> getRowsToWrite(long rowsToAdd, Schema schema,
boolean hasNulls) {
ArrayList<Row> data = new ArrayList<>();
+ int numFields = schema.getFields().size();
for (int i = 0; i < rowsToAdd; i++) {
-
- Row row =
- schema.getFields().stream()
- .map(field -> dummyFieldValue(field.getType()))
- .collect(Row.toRow(schema));
- data.add(row);
+ Row.Builder builder = Row.withSchema(schema);
+ for (int j = 0; j < numFields; j++) {
+ if (hasNulls && i % numFields == j &&
schema.getField(j).getType().getNullable()) {
+ builder.addValue(null);
+ } else {
+ builder.addValue(dummyFieldValue(schema.getField(j).getType()));
+ }
+ }
+ data.add(builder.build());
}
return data;
}
+ private static ArrayList<Row> getRowsToWrite(long rowsToAdd, Schema schema) {
+ return getRowsToWrite(rowsToAdd, schema, false);
+ }
+
+ private static ArrayList<Row> getNullableRowsToWrite(long rowsToAdd, Schema
schema) {
+ return getRowsToWrite(rowsToAdd, schema, true);
+ }
+
private static ArrayList<RowWithSchema> getRowsWithSchemaToWrite(long
rowsToAdd) {
ArrayList<RowWithSchema> data = new ArrayList<>();
@@ -1118,7 +1131,8 @@ public class JdbcIOTest implements Serializable {
return data;
}
- private static Object dummyFieldValue(Schema.FieldType fieldType) {
+ private static Object dummyFieldValue(Schema.FieldType maybeNullableType) {
+ Schema.FieldType fieldType = maybeNullableType.withNullable(false);
long epochMilli = 1558719710000L;
if (fieldType.equals(Schema.FieldType.STRING)) {
return "string value";
@@ -1134,7 +1148,12 @@ public class JdbcIOTest implements Serializable {
return Long.MAX_VALUE;
} else if (fieldType.equals(Schema.FieldType.FLOAT)) {
return 15.5F;
- } else if (fieldType.equals(Schema.FieldType.DECIMAL)) {
+ } else if (fieldType.equals(Schema.FieldType.DECIMAL)
+ || (fieldType.getLogicalType() != null
+ && fieldType
+ .getLogicalType()
+ .getIdentifier()
+ .equals(FixedPrecisionNumeric.IDENTIFIER))) {
return BigDecimal.ONE;
} else if (fieldType.equals(LogicalTypes.JDBC_DATE_TYPE)) {
return new DateTime(epochMilli,
ISOChronology.getInstanceUTC()).withTimeAtStartOfDay();
@@ -1326,6 +1345,63 @@ public class JdbcIOTest implements Serializable {
pipeline.run().waitUntilFinish();
}
+ @Test
+ public void testWriteReadNullableTypes() throws SQLException {
+ // first setup data
+ Schema.Builder schemaBuilder = Schema.builder();
+ schemaBuilder.addField("column_id", FieldType.INT32.withNullable(false));
+ schemaBuilder.addField("column_bigint",
Schema.FieldType.INT64.withNullable(true));
+ schemaBuilder.addField("column_boolean",
FieldType.BOOLEAN.withNullable(true));
+ schemaBuilder.addField("column_float",
Schema.FieldType.FLOAT.withNullable(true));
+ schemaBuilder.addField("column_double",
Schema.FieldType.DOUBLE.withNullable(true));
+ schemaBuilder.addField(
+ "column_decimal",
+ FieldType.logicalType(FixedPrecisionNumeric.of(13,
0)).withNullable(true));
+ Schema schema = schemaBuilder.build();
+
+ // some types not supported in derby (e.g. tinyint) are not tested here
+ String tableName =
DatabaseTestHelper.getTestTableName("UT_READ_NULLABLE_LG");
+ StringBuilder stmt = new StringBuilder("CREATE TABLE ");
+ stmt.append(tableName);
+ stmt.append(" (");
+ stmt.append("column_id INTEGER NOT NULL,"); // Integer
+ stmt.append("column_bigint BIGINT,"); // int64
+ stmt.append("column_boolean BOOLEAN,"); // boolean
+ stmt.append("column_float REAL,"); // float
+ stmt.append("column_double DOUBLE PRECISION,"); // double
+ stmt.append("column_decimal DECIMAL(13,0)"); // BigDecimal
+ stmt.append(" )");
+ DatabaseTestHelper.createTableWithStatement(DATA_SOURCE, stmt.toString());
+ final int rowsToAdd = 10;
+ try {
+ // run write pipeline
+ ArrayList<Row> data = getNullableRowsToWrite(rowsToAdd, schema);
+ pipeline
+ .apply(Create.of(data))
+ .setRowSchema(schema)
+ .apply(
+ JdbcIO.<Row>write()
+ .withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
+ .withBatchSize(10L)
+ .withTable(tableName));
+ pipeline.run();
+ assertRowCount(DATA_SOURCE, tableName, rowsToAdd);
+
+ // run read pipeline
+ PCollection<Row> rows =
+ secondPipeline.apply(
+ JdbcIO.readRows()
+ .withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
+ .withQuery("SELECT * FROM " + tableName));
+ PAssert.thatSingleton(rows.apply("Count All",
Count.globally())).isEqualTo((long) rowsToAdd);
+ PAssert.that(rows).containsInAnyOrder(data);
+
+ secondPipeline.run();
+ } finally {
+ DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName);
+ }
+ }
+
@Test
public void testPartitioningLongs() {
PCollection<KV<Long, Long>> ranges =