mosche commented on a change in pull request #15848:
URL: https://github.com/apache/beam/pull/15848#discussion_r801420304



##########
File path: 
sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
##########
@@ -333,31 +343,154 @@ private static Calendar 
withTimestampAndTimezone(DateTime dateTime) {
     return calendar;
   }
 
+  /**
+   * A helper for {@link ReadWithPartitions} that handles range calculations.
+   *
+   * @param <PartitionT>
+   */
+  interface JdbcReadWithPartitionsHelper<PartitionT>
+      extends PreparedStatementSetter<KV<PartitionT, PartitionT>>,
+          RowMapper<KV<Long, KV<PartitionT, PartitionT>>> {
+    static <T> JdbcReadWithPartitionsHelper<T> 
getPartitionsHelper(TypeDescriptor<T> type) {
+      // This cast is unchecked, thus this is a small type-checking risk. We 
just need
+      // to make sure that all preset helpers in `JdbcUtil.PRESET_HELPERS` are 
matched
+      // in type from their Key and their Value.
+      return (JdbcReadWithPartitionsHelper<T>) 
PRESET_HELPERS.get(type.getRawType());
+    }
+
+    Iterable<KV<PartitionT, PartitionT>> calculateRanges(
+        PartitionT lowerBound, PartitionT upperBound, Long partitions);
+
+    @Override
+    void setParameters(KV<PartitionT, PartitionT> element, PreparedStatement 
preparedStatement);
+
+    @Override
+    KV<Long, KV<PartitionT, PartitionT>> mapRow(ResultSet resultSet) throws 
Exception;
+  }
+
   /** Create partitions on a table. */
-  static class PartitioningFn extends DoFn<KV<Integer, KV<Long, Long>>, 
KV<String, Long>> {
+  static class PartitioningFn<T> extends DoFn<KV<Long, KV<T, T>>, KV<T, T>> {
+    private static final Logger LOG = 
LoggerFactory.getLogger(PartitioningFn.class);
+    final TypeDescriptor<T> partitioningColumnType;
+
+    PartitioningFn(TypeDescriptor<T> partitioningColumnType) {
+      this.partitioningColumnType = partitioningColumnType;
+    }
+
     @ProcessElement
     public void processElement(ProcessContext c) {
-      Integer numPartitions = c.element().getKey();
-      Long lowerBound = c.element().getValue().getKey();
-      Long upperBound = c.element().getValue().getValue();
-      if (lowerBound > upperBound) {
-        throw new RuntimeException(
-            String.format(
-                "Lower bound [%s] is higher than upper bound [%s]", 
lowerBound, upperBound));
-      }
-      long stride = (upperBound - lowerBound) / numPartitions + 1;
-      for (long i = lowerBound; i < upperBound - stride; i += stride) {
-        String range = String.format("%s,%s", i, i + stride);
-        KV<String, Long> kvRange = KV.of(range, 1L);
-        c.output(kvRange);
-      }
-      if (upperBound - lowerBound > stride * (numPartitions - 1)) {
-        long indexFrom = (numPartitions - 1) * stride;
-        long indexTo = upperBound + 1;
-        String range = String.format("%s,%s", indexFrom, indexTo);
-        KV<String, Long> kvRange = KV.of(range, 1L);
-        c.output(kvRange);
+      T lowerBound = c.element().getValue().getKey();
+      T upperBound = c.element().getValue().getValue();
+      JdbcReadWithPartitionsHelper<T> helper =
+          
JdbcReadWithPartitionsHelper.getPartitionsHelper(partitioningColumnType);
+      List<KV<T, T>> ranges =
+          Lists.newArrayList(helper.calculateRanges(lowerBound, upperBound, 
c.element().getKey()));
+      LOG.warn("Total of {} ranges: {}", ranges.size(), ranges);
+      for (KV<T, T> e : ranges) {
+        c.output(e);
       }
     }
   }
+
+  public static final Map<Class<?>, JdbcReadWithPartitionsHelper<?>> 
PRESET_HELPERS =
+      ImmutableMap.of(
+          Long.class,
+          new JdbcReadWithPartitionsHelper<Long>() {
+            @Override
+            public Iterable<KV<Long, Long>> calculateRanges(
+                Long lowerBound, Long upperBound, Long partitions) {
+              List<KV<Long, Long>> ranges = new ArrayList<>();
+              // We divide by partitions FIRST to make sure that we can cover 
the whole LONG range.
+              // If we substract first, then we may end up with Long.MAX - 
Long.MIN, which is 2*MAX,
+              // and we'd have trouble with the pipeline.
+              long stride = (upperBound / partitions - lowerBound / 
partitions) + 1;
+              long highest = lowerBound;
+              for (long i = lowerBound; i < upperBound - stride; i += stride) {
+                ranges.add(KV.of(i, i + stride));
+                highest = i + stride;
+              }
+              if (upperBound - lowerBound > stride * (ranges.size() - 1)) {
+                long indexFrom = highest;
+                long indexTo = upperBound + 1;
+                ranges.add(KV.of(indexFrom, indexTo));
+              }
+              return ranges;
+            }
+
+            @Override
+            public void setParameters(KV<Long, Long> element, 
PreparedStatement preparedStatement) {
+              try {
+                preparedStatement.setLong(1, element.getKey());
+                preparedStatement.setLong(2, element.getValue());
+              } catch (SQLException e) {
+                throw new RuntimeException(e);
+              }
+            }
+
+            @Override
+            public KV<Long, KV<Long, Long>> mapRow(ResultSet resultSet) throws 
Exception {
+              if (resultSet.getMetaData().getColumnCount() == 3) {
+                return KV.of(
+                    resultSet.getLong(3), KV.of(resultSet.getLong(1), 
resultSet.getLong(2)));
+              } else {
+                return KV.of(0L, KV.of(resultSet.getLong(1), 
resultSet.getLong(2)));
+              }
+            }
+          },
+          DateTime.class,
+          new JdbcReadWithPartitionsHelper<DateTime>() {
+            @Override
+            public Iterable<KV<DateTime, DateTime>> calculateRanges(
+                DateTime lowerBound, DateTime upperBound, Long partitions) {
+              final List<KV<DateTime, DateTime>> result = new ArrayList<>();
+
+              final long intervalMillis = upperBound.getMillis() - 
lowerBound.getMillis();
+              final long strideMillis =

Review comment:
       Nitpick: how about the following?
   ```java
   final Duration stride = Duration.millis(Math.max(1, intervalMillis / partitions));
   ```

##########
File path: 
sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
##########
@@ -333,31 +343,154 @@ private static Calendar 
withTimestampAndTimezone(DateTime dateTime) {
     return calendar;
   }
 
+  /**
+   * A helper for {@link ReadWithPartitions} that handles range calculations.
+   *
+   * @param <PartitionT>
+   */
+  interface JdbcReadWithPartitionsHelper<PartitionT>
+      extends PreparedStatementSetter<KV<PartitionT, PartitionT>>,
+          RowMapper<KV<Long, KV<PartitionT, PartitionT>>> {
+    static <T> JdbcReadWithPartitionsHelper<T> 
getPartitionsHelper(TypeDescriptor<T> type) {
+      // This cast is unchecked, thus this is a small type-checking risk. We 
just need
+      // to make sure that all preset helpers in `JdbcUtil.PRESET_HELPERS` are 
matched
+      // in type from their Key and their Value.
+      return (JdbcReadWithPartitionsHelper<T>) 
PRESET_HELPERS.get(type.getRawType());
+    }
+
+    Iterable<KV<PartitionT, PartitionT>> calculateRanges(
+        PartitionT lowerBound, PartitionT upperBound, Long partitions);
+
+    @Override
+    void setParameters(KV<PartitionT, PartitionT> element, PreparedStatement 
preparedStatement);
+
+    @Override
+    KV<Long, KV<PartitionT, PartitionT>> mapRow(ResultSet resultSet) throws 
Exception;
+  }
+
   /** Create partitions on a table. */
-  static class PartitioningFn extends DoFn<KV<Integer, KV<Long, Long>>, 
KV<String, Long>> {
+  static class PartitioningFn<T> extends DoFn<KV<Long, KV<T, T>>, KV<T, T>> {
+    private static final Logger LOG = 
LoggerFactory.getLogger(PartitioningFn.class);
+    final TypeDescriptor<T> partitioningColumnType;
+
+    PartitioningFn(TypeDescriptor<T> partitioningColumnType) {
+      this.partitioningColumnType = partitioningColumnType;
+    }
+
     @ProcessElement
     public void processElement(ProcessContext c) {
-      Integer numPartitions = c.element().getKey();
-      Long lowerBound = c.element().getValue().getKey();
-      Long upperBound = c.element().getValue().getValue();
-      if (lowerBound > upperBound) {
-        throw new RuntimeException(
-            String.format(
-                "Lower bound [%s] is higher than upper bound [%s]", 
lowerBound, upperBound));
-      }
-      long stride = (upperBound - lowerBound) / numPartitions + 1;
-      for (long i = lowerBound; i < upperBound - stride; i += stride) {
-        String range = String.format("%s,%s", i, i + stride);
-        KV<String, Long> kvRange = KV.of(range, 1L);
-        c.output(kvRange);
-      }
-      if (upperBound - lowerBound > stride * (numPartitions - 1)) {
-        long indexFrom = (numPartitions - 1) * stride;
-        long indexTo = upperBound + 1;
-        String range = String.format("%s,%s", indexFrom, indexTo);
-        KV<String, Long> kvRange = KV.of(range, 1L);
-        c.output(kvRange);
+      T lowerBound = c.element().getValue().getKey();
+      T upperBound = c.element().getValue().getValue();
+      JdbcReadWithPartitionsHelper<T> helper =
+          
JdbcReadWithPartitionsHelper.getPartitionsHelper(partitioningColumnType);
+      List<KV<T, T>> ranges =
+          Lists.newArrayList(helper.calculateRanges(lowerBound, upperBound, 
c.element().getKey()));
+      LOG.warn("Total of {} ranges: {}", ranges.size(), ranges);
+      for (KV<T, T> e : ranges) {
+        c.output(e);
       }
     }
   }
+
+  public static final Map<Class<?>, JdbcReadWithPartitionsHelper<?>> 
PRESET_HELPERS =
+      ImmutableMap.of(
+          Long.class,
+          new JdbcReadWithPartitionsHelper<Long>() {
+            @Override
+            public Iterable<KV<Long, Long>> calculateRanges(
+                Long lowerBound, Long upperBound, Long partitions) {
+              List<KV<Long, Long>> ranges = new ArrayList<>();
+              // We divide by partitions FIRST to make sure that we can cover 
the whole LONG range.
+              // If we substract first, then we may end up with Long.MAX - 
Long.MIN, which is 2*MAX,
+              // and we'd have trouble with the pipeline.
+              long stride = (upperBound / partitions - lowerBound / 
partitions) + 1;
+              long highest = lowerBound;
+              for (long i = lowerBound; i < upperBound - stride; i += stride) {
+                ranges.add(KV.of(i, i + stride));
+                highest = i + stride;
+              }
+              if (upperBound - lowerBound > stride * (ranges.size() - 1)) {
+                long indexFrom = highest;
+                long indexTo = upperBound + 1;
+                ranges.add(KV.of(indexFrom, indexTo));
+              }

Review comment:
       Nitpick: I would suggest simplifying this as follows:
   ```java
   if (highest < upperBound + 1) {
     ranges.add(KV.of(highest, upperBound + 1));
   }
   ```

##########
File path: 
sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java
##########
@@ -28,6 +44,66 @@
 @RunWith(JUnit4.class)
 public class JdbcUtilTest {
 
+  static final JdbcReadWithPartitionsHelper<String> 
PROTOTYPE_STRING_PARTITIONER =

Review comment:
       Are you planning to park the code here? If so, I would suggest at least 
adding a comment that references the follow-up ticket.

##########
File path: 
sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOAutoPartitioningIT.java
##########
@@ -0,0 +1,403 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.jdbc;
+
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.Objects;
+import java.util.Random;
+import javax.sql.DataSource;
+import org.apache.beam.sdk.PipelineResult;
+import org.apache.beam.sdk.io.GenerateSequence;
+import org.apache.beam.sdk.io.common.DatabaseTestHelper;
+import org.apache.beam.sdk.io.jdbc.JdbcIO.RowMapper;
+import org.apache.beam.sdk.metrics.Distribution;
+import org.apache.beam.sdk.metrics.Metrics;
+import org.apache.beam.sdk.schemas.JavaFieldSchema;
+import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
+import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+import org.joda.time.DateTime;
+import org.joda.time.Duration;
+import org.joda.time.Instant;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TestRule;
+import org.junit.runner.Description;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.model.Statement;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testcontainers.containers.JdbcDatabaseContainer;
+import org.testcontainers.containers.MySQLContainer;
+import org.testcontainers.containers.PostgreSQLContainer;
+
+/** A test of {@link org.apache.beam.sdk.io.jdbc.JdbcIO} on test containers. */
+@RunWith(Parameterized.class)
+public class JdbcIOAutoPartitioningIT {
+  private static final Logger LOG = 
LoggerFactory.getLogger(JdbcIOAutoPartitioningIT.class);
+
+  public static final Integer NUM_ROWS = 1_000;
+  public static final String TABLE_NAME = "baseTable";
+
+  @ClassRule public static TestPipeline pipelineWrite = TestPipeline.create();
+  @Rule public TestPipeline pipelineRead = TestPipeline.create();
+
+  @Parameterized.Parameters(name = "{0}")
+  public static Iterable<String> params() {
+    return Lists.newArrayList("mysql", "postgres");
+  }
+
+  @Parameterized.Parameter(0)
+  public String dbms;
+
+  public static JdbcDatabaseContainer<?> getDb(String dbName) {
+    if (dbName.equals("mysql")) {
+      return mysql;
+    } else {
+      return postgres;
+    }
+  }
+
+  // We need to implement this retrying rule because we're running ~18 
pipelines that connect to
+  // two databases in a few seconds. This may cause the databases to be 
overwhelmed, and reject
+  // connections or error out. By using this rule, we ensure that each 
pipeline is trued twice so
+  // that flakiness from databases being overwhelmed can be managed.
+  @Rule
+  public TestRule retryRule =
+      new TestRule() {
+        // We establish max number of retries at 2.
+        public final int maxRetries = 2;
+
+        @Override
+        public Statement apply(Statement base, Description description) {
+          return new Statement() {
+            @Override
+            public void evaluate() throws Throwable {
+              Throwable caughtThrowable = null;
+              // implement retry logic here
+              for (int i = 0; i < maxRetries; i++) {
+                try {
+                  pipelineRead.apply(base, description);
+                  // base.evaluate();

Review comment:
       Be careful, @pabloem: with this line commented out, no tests are actually 
running. Everything just succeeds...
   I almost missed this, but I was surprised to see the tests for string 
partitioning pass.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to