This is an automated email from the ASF dual-hosted git repository.

yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4cc0bccdf56 Fix excessive checkStateNotNull in JdbcUtil (#25847)
4cc0bccdf56 is described below

commit 4cc0bccdf56fca6a219bbd210fcbbc32666b738d
Author: Yi Hu <[email protected]>
AuthorDate: Wed Mar 15 11:29:52 2023 -0400

    Fix excessive checkStateNotNull in JdbcUtil (#25847)
    
    * Also fix cannot write null values in derby (unit test)
---
 .../java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java |  19 ++--
 .../org/apache/beam/sdk/io/jdbc/SchemaUtil.java    |   6 +-
 .../org/apache/beam/sdk/io/jdbc/JdbcIOTest.java    | 110 +++++++++++++++++----
 3 files changed, 107 insertions(+), 28 deletions(-)

diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
index 04a24dd0daf..41be4074ef6 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
@@ -85,7 +85,7 @@ class JdbcUtil {
                   (element, ps, i, fieldWithIndex) -> {
                     Byte value = element.getByte(fieldWithIndex.getIndex());
                     if (value == null) {
-                      setNullToPreparedStatement(ps, i);
+                      setNullToPreparedStatement(ps, i, JDBCType.TINYINT);
                     } else {
                       ps.setByte(i + 1, value);
                     }
@@ -95,7 +95,7 @@ class JdbcUtil {
                   (element, ps, i, fieldWithIndex) -> {
                     Short value = element.getInt16(fieldWithIndex.getIndex());
                     if (value == null) {
-                      setNullToPreparedStatement(ps, i);
+                      setNullToPreparedStatement(ps, i, JDBCType.SMALLINT);
                     } else {
                       ps.setInt(i + 1, value);
                     }
@@ -105,7 +105,7 @@ class JdbcUtil {
                   (element, ps, i, fieldWithIndex) -> {
                     Long value = element.getInt64(fieldWithIndex.getIndex());
                     if (value == null) {
-                      setNullToPreparedStatement(ps, i);
+                      setNullToPreparedStatement(ps, i, JDBCType.BIGINT);
                     } else {
                       ps.setLong(i + 1, value);
                     }
@@ -119,7 +119,7 @@ class JdbcUtil {
                   (element, ps, i, fieldWithIndex) -> {
                     Float value = element.getFloat(fieldWithIndex.getIndex());
                     if (value == null) {
-                      setNullToPreparedStatement(ps, i);
+                      setNullToPreparedStatement(ps, i, JDBCType.FLOAT);
                     } else {
                       ps.setFloat(i + 1, value);
                     }
@@ -129,7 +129,7 @@ class JdbcUtil {
                   (element, ps, i, fieldWithIndex) -> {
                     Double value = element.getDouble(fieldWithIndex.getIndex());
                     if (value == null) {
-                      setNullToPreparedStatement(ps, i);
+                      setNullToPreparedStatement(ps, i, JDBCType.DOUBLE);
                     } else {
                       ps.setDouble(i + 1, value);
                     }
@@ -145,7 +145,7 @@ class JdbcUtil {
                   (element, ps, i, fieldWithIndex) -> {
                     Boolean value = element.getBoolean(fieldWithIndex.getIndex());
                     if (value == null) {
-                      setNullToPreparedStatement(ps, i);
+                      setNullToPreparedStatement(ps, i, JDBCType.BOOLEAN);
                     } else {
                       ps.setBoolean(i + 1, value);
                     }
@@ -156,7 +156,7 @@ class JdbcUtil {
                   (element, ps, i, fieldWithIndex) -> {
                     Integer value = element.getInt32(fieldWithIndex.getIndex());
                     if (value == null) {
-                      setNullToPreparedStatement(ps, i);
+                      setNullToPreparedStatement(ps, i, JDBCType.INTEGER);
                     } else {
                       ps.setInt(i + 1, value);
                     }
@@ -267,8 +267,9 @@ class JdbcUtil {
     ps.setArray(i + 1, null);
   }
 
-  static void setNullToPreparedStatement(PreparedStatement ps, int i) throws SQLException {
-    ps.setNull(i + 1, JDBCType.NULL.getVendorTypeNumber());
+  static void setNullToPreparedStatement(PreparedStatement ps, int i, JDBCType type)
+      throws SQLException {
+    ps.setNull(i + 1, type.getVendorTypeNumber());
   }
 
   static class BeamRowPreparedStatementSetter implements JdbcIO.PreparedStatementSetter<Row> {
diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java
index 05414fee0be..458b0c2c82a 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java
@@ -27,7 +27,6 @@ import static java.sql.JDBCType.NVARCHAR;
 import static java.sql.JDBCType.VARBINARY;
 import static java.sql.JDBCType.VARCHAR;
 import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull;
-import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
 
 import java.io.Serializable;
@@ -304,7 +303,10 @@ class SchemaUtil {
     } else {
       ResultSetFieldExtractor extractor = createFieldExtractor(fieldType.getBaseType());
       return (rs, index) -> {
-        BaseT v = checkStateNotNull((BaseT) extractor.extract(rs, index));
+        BaseT v = (BaseT) extractor.extract(rs, index);
+        if (v == null) {
+          return null;
+        }
         return fieldType.toInputType(v);
       };
     }
diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
index 72fb670b2ac..0e1874a08b7 100644
--- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
+++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
@@ -17,7 +17,6 @@
  */
 package org.apache.beam.sdk.io.jdbc;
 
-import static java.sql.JDBCType.NULL;
 import static org.apache.beam.sdk.io.common.DatabaseTestHelper.assertRowCount;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.closeTo;
@@ -119,6 +118,8 @@ public class JdbcIOTest implements Serializable {
 
   @Rule public final transient TestPipeline pipeline = TestPipeline.create();
 
+  @Rule public final transient TestPipeline secondPipeline = TestPipeline.create();
+
   @Rule public final transient ExpectedLogs expectedLogs = ExpectedLogs.none(JdbcIO.class);
 
   @Rule public transient ExpectedException thrown = ExpectedException.none();
@@ -991,13 +992,13 @@ public class JdbcIOTest implements Serializable {
         .set(row, psMocked, 10, SchemaUtil.FieldWithIndex.of(schema.getField(8), 8));
 
     // primitive
-    verify(psMocked, times(1)).setNull(1, NULL.getVendorTypeNumber());
-    verify(psMocked, times(1)).setNull(2, NULL.getVendorTypeNumber());
-    verify(psMocked, times(1)).setNull(3, NULL.getVendorTypeNumber());
-    verify(psMocked, times(1)).setNull(4, NULL.getVendorTypeNumber());
-    verify(psMocked, times(1)).setNull(5, NULL.getVendorTypeNumber());
-    verify(psMocked, times(1)).setNull(6, NULL.getVendorTypeNumber());
-    verify(psMocked, times(1)).setNull(7, NULL.getVendorTypeNumber());
+    verify(psMocked, times(1)).setNull(1, JDBCType.BIGINT.getVendorTypeNumber());
+    verify(psMocked, times(1)).setNull(2, JDBCType.BOOLEAN.getVendorTypeNumber());
+    verify(psMocked, times(1)).setNull(3, JDBCType.DOUBLE.getVendorTypeNumber());
+    verify(psMocked, times(1)).setNull(4, JDBCType.FLOAT.getVendorTypeNumber());
+    verify(psMocked, times(1)).setNull(5, JDBCType.INTEGER.getVendorTypeNumber());
+    verify(psMocked, times(1)).setNull(6, JDBCType.SMALLINT.getVendorTypeNumber());
+    verify(psMocked, times(1)).setNull(7, JDBCType.TINYINT.getVendorTypeNumber());
     // reference
     verify(psMocked, times(1)).setBytes(8, null);
     verify(psMocked, times(1)).setString(9, null);
@@ -1095,20 +1096,32 @@ public class JdbcIOTest implements Serializable {
     verify(psMocked, times(1)).setArray(1, arrayMocked);
   }
 
-  private static ArrayList<Row> getRowsToWrite(long rowsToAdd, Schema schema) {
+  private static ArrayList<Row> getRowsToWrite(long rowsToAdd, Schema schema, boolean hasNulls) {
 
     ArrayList<Row> data = new ArrayList<>();
+    int numFields = schema.getFields().size();
     for (int i = 0; i < rowsToAdd; i++) {
-
-      Row row =
-          schema.getFields().stream()
-              .map(field -> dummyFieldValue(field.getType()))
-              .collect(Row.toRow(schema));
-      data.add(row);
+      Row.Builder builder = Row.withSchema(schema);
+      for (int j = 0; j < numFields; j++) {
+        if (hasNulls && i % numFields == j && schema.getField(j).getType().getNullable()) {
+          builder.addValue(null);
+        } else {
+          builder.addValue(dummyFieldValue(schema.getField(j).getType()));
+        }
+      }
+      data.add(builder.build());
     }
     return data;
   }
 
+  private static ArrayList<Row> getRowsToWrite(long rowsToAdd, Schema schema) {
+    return getRowsToWrite(rowsToAdd, schema, false);
+  }
+
+  private static ArrayList<Row> getNullableRowsToWrite(long rowsToAdd, Schema schema) {
+    return getRowsToWrite(rowsToAdd, schema, true);
+  }
+
   private static ArrayList<RowWithSchema> getRowsWithSchemaToWrite(long rowsToAdd) {
 
     ArrayList<RowWithSchema> data = new ArrayList<>();
@@ -1118,7 +1131,8 @@ public class JdbcIOTest implements Serializable {
     return data;
   }
 
-  private static Object dummyFieldValue(Schema.FieldType fieldType) {
+  private static Object dummyFieldValue(Schema.FieldType maybeNullableType) {
+    Schema.FieldType fieldType = maybeNullableType.withNullable(false);
     long epochMilli = 1558719710000L;
     if (fieldType.equals(Schema.FieldType.STRING)) {
       return "string value";
@@ -1134,7 +1148,12 @@ public class JdbcIOTest implements Serializable {
       return Long.MAX_VALUE;
     } else if (fieldType.equals(Schema.FieldType.FLOAT)) {
       return 15.5F;
-    } else if (fieldType.equals(Schema.FieldType.DECIMAL)) {
+    } else if (fieldType.equals(Schema.FieldType.DECIMAL)
+        || (fieldType.getLogicalType() != null
+            && fieldType
+                .getLogicalType()
+                .getIdentifier()
+                .equals(FixedPrecisionNumeric.IDENTIFIER))) {
       return BigDecimal.ONE;
     } else if (fieldType.equals(LogicalTypes.JDBC_DATE_TYPE)) {
       return new DateTime(epochMilli, ISOChronology.getInstanceUTC()).withTimeAtStartOfDay();
@@ -1326,6 +1345,63 @@ public class JdbcIOTest implements Serializable {
     pipeline.run().waitUntilFinish();
   }
 
+  @Test
+  public void testWriteReadNullableTypes() throws SQLException {
+    // first setup data
+    Schema.Builder schemaBuilder = Schema.builder();
+    schemaBuilder.addField("column_id", FieldType.INT32.withNullable(false));
+    schemaBuilder.addField("column_bigint", Schema.FieldType.INT64.withNullable(true));
+    schemaBuilder.addField("column_boolean", FieldType.BOOLEAN.withNullable(true));
+    schemaBuilder.addField("column_float", Schema.FieldType.FLOAT.withNullable(true));
+    schemaBuilder.addField("column_double", Schema.FieldType.DOUBLE.withNullable(true));
+    schemaBuilder.addField(
+        "column_decimal",
+        FieldType.logicalType(FixedPrecisionNumeric.of(13, 0)).withNullable(true));
+    Schema schema = schemaBuilder.build();
+
+    // some types not supported in derby (e.g. tinyint) are not tested here
+    String tableName = DatabaseTestHelper.getTestTableName("UT_READ_NULLABLE_LG");
+    StringBuilder stmt = new StringBuilder("CREATE TABLE ");
+    stmt.append(tableName);
+    stmt.append(" (");
+    stmt.append("column_id      INTEGER NOT NULL,"); // Integer
+    stmt.append("column_bigint  BIGINT,"); // int64
+    stmt.append("column_boolean BOOLEAN,"); // boolean
+    stmt.append("column_float  REAL,"); // float
+    stmt.append("column_double  DOUBLE PRECISION,"); // double
+    stmt.append("column_decimal    DECIMAL(13,0)"); // BigDecimal
+    stmt.append(" )");
+    DatabaseTestHelper.createTableWithStatement(DATA_SOURCE, stmt.toString());
+    final int rowsToAdd = 10;
+    try {
+      // run write pipeline
+      ArrayList<Row> data = getNullableRowsToWrite(rowsToAdd, schema);
+      pipeline
+          .apply(Create.of(data))
+          .setRowSchema(schema)
+          .apply(
+              JdbcIO.<Row>write()
+                  .withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
+                  .withBatchSize(10L)
+                  .withTable(tableName));
+      pipeline.run();
+      assertRowCount(DATA_SOURCE, tableName, rowsToAdd);
+
+      // run read pipeline
+      PCollection<Row> rows =
+          secondPipeline.apply(
+              JdbcIO.readRows()
+                  .withDataSourceConfiguration(DATA_SOURCE_CONFIGURATION)
+                  .withQuery("SELECT * FROM " + tableName));
+      PAssert.thatSingleton(rows.apply("Count All", Count.globally())).isEqualTo((long) rowsToAdd);
+      PAssert.that(rows).containsInAnyOrder(data);
+
+      secondPipeline.run();
+    } finally {
+      DatabaseTestHelper.deleteTable(DATA_SOURCE, tableName);
+    }
+  }
+
   @Test
   public void testPartitioningLongs() {
     PCollection<KV<Long, Long>> ranges =

Reply via email to