mosche commented on a change in pull request #15848:
URL: https://github.com/apache/beam/pull/15848#discussion_r799531475



##########
File path: sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
##########
@@ -333,31 +343,206 @@ private static Calendar withTimestampAndTimezone(DateTime dateTime) {
     return calendar;
   }
 
+  /**
+   * A helper for {@link ReadWithPartitions} that handles range calculations.
+   *
+   * @param <PartitionT>
+   */
+  interface JdbcReadWithPartitionsHelper<PartitionT>
+      extends PreparedStatementSetter<KV<PartitionT, PartitionT>>,
+          RowMapper<KV<Long, KV<PartitionT, PartitionT>>> {
+    Iterable<KV<PartitionT, PartitionT>> calculateRanges(
+        PartitionT lowerBound, PartitionT upperBound, Long partitions);
+
+    @Override
+    void setParameters(KV<PartitionT, PartitionT> element, PreparedStatement preparedStatement);
+
+    @Override
+    KV<Long, KV<PartitionT, PartitionT>> mapRow(ResultSet resultSet) throws Exception;
+  }
+
   /** Create partitions on a table. */
-  static class PartitioningFn extends DoFn<KV<Integer, KV<Long, Long>>, KV<String, Long>> {
+  static class PartitioningFn<T> extends DoFn<KV<Long, KV<T, T>>, KV<T, T>> {
+    private static final Logger LOG = LoggerFactory.getLogger(PartitioningFn.class);
+    final TypeDescriptor<T> partitioningColumnType;
+
+    PartitioningFn(TypeDescriptor<T> partitioningColumnType) {
+      this.partitioningColumnType = partitioningColumnType;
+    }
+
     @ProcessElement
     public void processElement(ProcessContext c) {
-      Integer numPartitions = c.element().getKey();
-      Long lowerBound = c.element().getValue().getKey();
-      Long upperBound = c.element().getValue().getValue();
-      if (lowerBound > upperBound) {
-        throw new RuntimeException(
-            String.format(
-                "Lower bound [%s] is higher than upper bound [%s]", lowerBound, upperBound));
-      }
-      long stride = (upperBound - lowerBound) / numPartitions + 1;
-      for (long i = lowerBound; i < upperBound - stride; i += stride) {
-        String range = String.format("%s,%s", i, i + stride);
-        KV<String, Long> kvRange = KV.of(range, 1L);
-        c.output(kvRange);
-      }
-      if (upperBound - lowerBound > stride * (numPartitions - 1)) {
-        long indexFrom = (numPartitions - 1) * stride;
-        long indexTo = upperBound + 1;
-        String range = String.format("%s,%s", indexFrom, indexTo);
-        KV<String, Long> kvRange = KV.of(range, 1L);
-        c.output(kvRange);
+      T lowerBound = c.element().getValue().getKey();
+      T upperBound = c.element().getValue().getValue();
+      JdbcReadWithPartitionsHelper<T> helper =
+          (JdbcReadWithPartitionsHelper<T>) PRESET_HELPERS.get(partitioningColumnType.getRawType());
+      List<KV<T, T>> ranges =
+          Lists.newArrayList(helper.calculateRanges(lowerBound, upperBound, c.element().getKey()));
+      LOG.warn("Total of {} ranges: {}", ranges.size(), ranges);
+      for (KV<T, T> e : ranges) {
+        c.output(e);
       }
     }
   }
+
+  public static final Map<Class<?>, JdbcReadWithPartitionsHelper<?>> PRESET_HELPERS =

Review comment:
   That would be great, thanks. I think such tests serve as the best possible documentation for other developers, and the ones in JdbcIOTest are already fairly high-level because they run a whole pipeline.
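As an illustration of the kind of lower-level test suggested here, the sketch below checks a range calculation directly, without a pipeline or database. It is a minimal, hypothetical example and not the helper from this PR: the names `LongRangeSketch`, `Range` and `check` are invented for illustration, `Range` stands in for Beam's `KV<Long, Long>`, and the ceiling-division stride is an assumption. It keeps the convention of the removed `PartitioningFn` code, where the last range ends at `upperBound + 1` so that the upper bound itself is read.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical, Beam-independent sketch of a Long range calculation with the same shape as
 * calculateRanges(lowerBound, upperBound, partitions). NOT Beam's implementation.
 */
public class LongRangeSketch {

  /** Half-open range [from, to); a stand-in for Beam's KV<Long, Long>. */
  static final class Range {
    final long from;
    final long to;

    Range(long from, long to) {
      this.from = from;
      this.to = to;
    }

    @Override
    public String toString() {
      return "[" + from + ", " + to + ")";
    }
  }

  /**
   * Splits the inclusive interval [lowerBound, upperBound] into at most {@code partitions}
   * contiguous half-open ranges. The last range ends at upperBound + 1 so rows equal to
   * upperBound are included. Overflow for very wide ranges or bounds near Long.MAX_VALUE is
   * not handled here.
   */
  static List<Range> calculateRanges(long lowerBound, long upperBound, long partitions) {
    if (partitions < 1) {
      throw new IllegalArgumentException("partitions must be >= 1, got " + partitions);
    }
    if (lowerBound > upperBound) {
      throw new IllegalArgumentException(
          String.format(
              "Lower bound [%s] is higher than upper bound [%s]", lowerBound, upperBound));
    }
    long width = upperBound - lowerBound + 1; // number of distinct values covered
    // Ceiling division guarantees that at most `partitions` ranges are produced.
    long stride = (width + partitions - 1) / partitions;
    List<Range> ranges = new ArrayList<>();
    for (long from = lowerBound; from <= upperBound; from += stride) {
      long to = Math.min(from + stride, upperBound + 1);
      ranges.add(new Range(from, to));
    }
    return ranges;
  }

  public static void main(String[] args) {
    // Unit-test style checks on the range logic alone, without running a pipeline.
    List<Range> ranges = calculateRanges(0L, 9L, 3L);
    System.out.println(ranges);
    check(ranges.size() == 3, "expected 3 ranges");
    check(ranges.get(0).from == 0L, "first range must start at the lower bound");
    check(ranges.get(ranges.size() - 1).to == 10L, "last range must end at upperBound + 1");
    for (int i = 1; i < ranges.size(); i++) {
      check(ranges.get(i - 1).to == ranges.get(i).from, "ranges must be contiguous");
    }
  }

  private static void check(boolean condition, String message) {
    if (!condition) {
      throw new AssertionError(message);
    }
  }
}
```

A pipeline-level test in JdbcIOTest exercises the same logic end to end; a direct check like this one pins down the boundary behavior (where the first range starts, where the last one ends, and that no gaps appear) and points straight at the failing rule when one breaks.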

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.