[ 
https://issues.apache.org/jira/browse/BEAM-3485?focusedWorklogId=96069&page=com.atlassian.jira.plugin.system.issuetabpanels:worklog-tabpanel#worklog-96069
 ]

ASF GitHub Bot logged work on BEAM-3485:
----------------------------------------

                Author: ASF GitHub Bot
            Created on: 27/Apr/18 13:22
            Start Date: 27/Apr/18 13:22
    Worklog Time Spent: 10m 
      Work Description: iemejia closed pull request #5124: [BEAM-3485] Fix split generation for Cassandra clusters
URL: https://github.com/apache/beam/pull/5124
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not show it for merged fork pull requests:

diff --git 
a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
 
b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
index 56db599823c..3e9f0142299 100644
--- 
a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
+++ 
b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraIO.java
@@ -80,7 +80,6 @@
  */
 @Experimental(Experimental.Kind.SOURCE_SINK)
 public class CassandraIO {
-
   private CassandraIO() {}
 
   /**
@@ -103,7 +102,6 @@ private CassandraIO() {}
    */
   @AutoValue
   public abstract static class Read<T> extends PTransform<PBegin, 
PCollection<T>> {
-
     @Nullable abstract List<String> hosts();
     @Nullable abstract Integer port();
     @Nullable abstract String keyspace();
@@ -114,6 +112,7 @@ private CassandraIO() {}
     @Nullable abstract String password();
     @Nullable abstract String localDc();
     @Nullable abstract String consistencyLevel();
+    @Nullable abstract Integer minNumberOfSplits();
     @Nullable abstract CassandraService<T> cassandraService();
     abstract Builder<T> builder();
 
@@ -197,6 +196,18 @@ private CassandraIO() {}
       return builder().setConsistencyLevel(consistencyLevel).build();
     }
 
+    /**
+     * It's possible that system.size_estimates isn't populated or that the 
number of splits
+     * computed by Beam is still to low for Cassandra to handle it.
+     * This setting allows to enforce a minimum number of splits in case Beam 
cannot compute
+     * it correctly.
+     */
+    public Read<T> withMinNumberOfSplits(Integer minNumberOfSplits) {
+      checkArgument(minNumberOfSplits != null, "minNumberOfSplits can not be 
null");
+      checkArgument(minNumberOfSplits > 0, "minNumberOfSplits must be greater 
than 0");
+      return builder().setMinNumberOfSplits(minNumberOfSplits).build();
+    }
+
     /**
      * Specify an instance of {@link CassandraService} used to connect and 
read from Cassandra
      * database.
@@ -231,6 +242,7 @@ private CassandraIO() {}
       abstract Builder<T> setPassword(String password);
       abstract Builder<T> setLocalDc(String localDc);
       abstract Builder<T> setConsistencyLevel(String consistencyLevel);
+      abstract Builder<T> setMinNumberOfSplits(Integer minNumberOfSplits);
       abstract Builder<T> setCassandraService(CassandraService<T> 
cassandraService);
       abstract Read<T> build();
     }
@@ -247,19 +259,16 @@ private CassandraIO() {}
       }
       return new CassandraServiceImpl<>();
     }
-
   }
 
   @VisibleForTesting
   static class CassandraSource<T> extends BoundedSource<T> {
+    final Read<T> spec;
+    final List<String> splitQueries;
 
-    protected final Read<T> spec;
-    protected final String splitQuery;
-
-    CassandraSource(Read<T> spec,
-                    String splitQuery) {
+    CassandraSource(Read<T> spec, List<String> splitQueries) {
       this.spec = spec;
-      this.splitQuery = splitQuery;
+      this.splitQueries = splitQueries;
     }
 
     @Override
@@ -273,15 +282,14 @@ private CassandraIO() {}
     }
 
     @Override
-    public long getEstimatedSizeBytes(PipelineOptions pipelineOptions) throws 
Exception {
+    public long getEstimatedSizeBytes(PipelineOptions pipelineOptions) {
       return spec.getCassandraService().getEstimatedSizeBytes(spec);
     }
 
     @Override
-    public List<BoundedSource<T>> split(long desiredBundleSizeBytes,
-                                                   PipelineOptions 
pipelineOptions) {
-      return spec.getCassandraService()
-          .split(spec, desiredBundleSizeBytes);
+    public List<BoundedSource<T>> split(
+        long desiredBundleSizeBytes, PipelineOptions pipelineOptions) {
+      return spec.getCassandraService().split(spec, desiredBundleSizeBytes);
     }
 
     @Override
diff --git 
a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraServiceImpl.java
 
b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraServiceImpl.java
index 5a52d2cbfa5..31a18193344 100644
--- 
a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraServiceImpl.java
+++ 
b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/CassandraServiceImpl.java
@@ -22,25 +22,28 @@
 import com.datastax.driver.core.PlainTextAuthProvider;
 import com.datastax.driver.core.QueryOptions;
 import com.datastax.driver.core.ResultSet;
+import com.datastax.driver.core.ResultSetFuture;
 import com.datastax.driver.core.Row;
 import com.datastax.driver.core.Session;
 import com.datastax.driver.core.policies.DCAwareRoundRobinPolicy;
-import com.datastax.driver.core.policies.RoundRobinPolicy;
 import com.datastax.driver.core.policies.TokenAwarePolicy;
 import com.datastax.driver.core.querybuilder.QueryBuilder;
 import com.datastax.driver.core.querybuilder.Select;
 import com.datastax.driver.mapping.Mapper;
 import com.datastax.driver.mapping.MappingManager;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Iterators;
 import com.google.common.collect.Lists;
 import com.google.common.util.concurrent.ListenableFuture;
 import java.io.IOException;
 import java.math.BigInteger;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
 import java.util.NoSuchElementException;
 import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
 import org.apache.beam.sdk.io.BoundedSource;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -51,19 +54,13 @@
 public class CassandraServiceImpl<T> implements CassandraService<T> {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(CassandraServiceImpl.class);
-
-  private static final long MIN_TOKEN = Long.MIN_VALUE;
-  private static final long MAX_TOKEN = Long.MAX_VALUE;
-  private static final BigInteger TOTAL_TOKEN_COUNT =
-      BigInteger.valueOf(MAX_TOKEN).subtract(BigInteger.valueOf(MIN_TOKEN));
+  private static final String MURMUR3PARTITIONER = 
"org.apache.cassandra.dht.Murmur3Partitioner";
 
   private class CassandraReaderImpl<T> extends BoundedSource.BoundedReader<T> {
 
     private final CassandraIO.CassandraSource<T> source;
-
     private Cluster cluster;
     private Session session;
-    private ResultSet resultSet;
     private Iterator<T> iterator;
     private T current;
 
@@ -77,12 +74,23 @@ public boolean start() throws IOException {
       cluster = getCluster(source.spec.hosts(), source.spec.port(), 
source.spec.username(),
           source.spec.password(), source.spec.localDc(), 
source.spec.consistencyLevel());
       session = cluster.connect();
-      LOG.debug("Query: " + source.splitQuery);
-      resultSet = session.execute(source.splitQuery);
+      LOG.debug("Queries: " + source.splitQueries);
+      List<ResultSetFuture> futures = Lists.newArrayList();
+      for (String query:source.splitQueries) {
+        futures.add(session.executeAsync(query));
+      }
 
       final MappingManager mappingManager = new MappingManager(session);
       Mapper mapper = mappingManager.mapper(source.spec.entity());
-      iterator = mapper.map(resultSet).iterator();
+
+      for (ResultSetFuture result:futures) {
+        if (iterator == null) {
+          iterator = mapper.map(result.getUninterruptibly()).iterator();
+        } else {
+          iterator = Iterators.concat(iterator, 
mapper.map(result.getUninterruptibly()).iterator());
+        }
+      }
+
       return advance();
     }
 
@@ -168,13 +176,13 @@ protected static long 
getEstimatedSizeBytes(List<TokenRange> tokenRanges) {
         spec.localDc(), spec.consistencyLevel())) {
       if (isMurmur3Partitioner(cluster)) {
         LOG.info("Murmur3Partitioner detected, splitting");
-        return split(spec, desiredBundleSizeBytes, 
getEstimatedSizeBytes(spec));
+        return split(spec, desiredBundleSizeBytes, 
getEstimatedSizeBytes(spec), cluster);
       } else {
         LOG.warn("Only Murmur3Partitioner is supported for splitting, using an 
unique source for "
             + "the read");
         String splitQuery = QueryBuilder.select().from(spec.keyspace(), 
spec.table()).toString();
         List<BoundedSource<T>> sources = new ArrayList<>();
-        sources.add(new CassandraIO.CassandraSource<>(spec, splitQuery));
+        sources.add(new CassandraIO.CassandraSource<>(spec, 
Arrays.asList(splitQuery)));
         return sources;
       }
     }
@@ -187,7 +195,17 @@ protected static long 
getEstimatedSizeBytes(List<TokenRange> tokenRanges) {
   @VisibleForTesting
   protected List<BoundedSource<T>> split(CassandraIO.Read<T> spec,
                                                 long desiredBundleSizeBytes,
-                                                long estimatedSizeBytes) {
+                                                long estimatedSizeBytes,
+                                                Cluster cluster) {
+    String partitionKey =
+        cluster.getMetadata()
+            .getKeyspace(spec.keyspace())
+            .getTable(spec.table())
+            .getPartitionKey()
+            .stream()
+              .map(partitionKeyColumn -> partitionKeyColumn.getName())
+              .collect(Collectors.joining(","));
+
     long numSplits = 1;
     List<BoundedSource<T>> sourceList = new ArrayList<>();
     if (desiredBundleSizeBytes > 0) {
@@ -198,34 +216,52 @@ protected static long 
getEstimatedSizeBytes(List<TokenRange> tokenRanges) {
       numSplits = 1;
     }
 
-    LOG.info("Number of splits is {}", numSplits);
+    if (null != spec.minNumberOfSplits()) {
+      numSplits = Math.max(numSplits, spec.minNumberOfSplits());
+    }
+    LOG.info("Number of desired splits is {}", numSplits);
 
-    double startRange = MIN_TOKEN;
-    double endRange = MAX_TOKEN;
-    double startToken, endToken;
+    SplitGenerator splitGenerator = new 
SplitGenerator(cluster.getMetadata().getPartitioner());
+    List<BigInteger> tokens = cluster.getMetadata().getTokenRanges().stream()
+        .map(tokenRange -> new 
BigInteger(tokenRange.getEnd().getValue().toString()))
+        .collect(Collectors.toList());
+    List<List<RingRange>> splits = splitGenerator.generateSplits(numSplits, 
tokens);
 
-    endToken = startRange;
-    double incrementValue = endRange - startRange / numSplits;
-    String splitQuery;
-    if (numSplits == 1) {
-      // we have an unique split
-      splitQuery = QueryBuilder.select().from(spec.keyspace(), 
spec.table()).toString();
-      sourceList.add(new CassandraIO.CassandraSource<>(spec, splitQuery));
-    } else {
-      // we have more than one split
-      for (int i = 0; i < numSplits; i++) {
-        startToken = endToken;
-        endToken = startToken + incrementValue;
+    LOG.info("{} splits were actually generated", splits.size());
+
+    for (List<RingRange> split:splits) {
+      List<String> queries = Lists.newArrayList();
+      for (RingRange range : split) {
         Select.Where builder = QueryBuilder.select().from(spec.keyspace(), 
spec.table()).where();
-        if (i > 0) {
-          builder = builder.and(QueryBuilder.gte("token($pk)", startToken));
-        }
-        if (i < (numSplits - 1)) {
-          builder = builder.and(QueryBuilder.lt("token($pk)", endToken));
+        if (range.isWrapping()) {
+          // A wrapping range is one that overlaps from the end of the 
partitioner range and its
+          // start (ie : when the start token of the split is greater than the 
end token)
+          // We need to generate two queries here : one that goes from the 
start token to the end of
+          // the partitioner range, and the other from the start of the 
partitioner range to the
+          // end token of the split.
+          builder = builder.and(QueryBuilder.gte("token(" + partitionKey + 
")", range.getStart()));
+          String query = builder.toString();
+          LOG.info("Cassandra generated read query : {}", query);
+          queries.add(query);
+
+          // Generation of the second query of the wrapping range
+          builder = QueryBuilder.select().from(spec.keyspace(), 
spec.table()).where();
+          builder = builder.and(QueryBuilder.lt("token(" + partitionKey + ")", 
range.getEnd()));
+          query = builder.toString();
+          LOG.info("Cassandra generated read query : {}", query);
+          queries.add(query);
+        } else {
+          builder = builder.and(QueryBuilder.gte("token(" + partitionKey + 
")", range.getStart()));
+          builder = builder.and(QueryBuilder.lt("token(" + partitionKey + ")", 
range.getEnd()));
+          String query = builder.toString();
+          LOG.info("Cassandra generated read query : {}", query);
+          queries.add(query);
         }
-        sourceList.add(new CassandraIO.CassandraSource(spec, 
builder.toString()));
       }
+      sourceList.add(new CassandraIO.CassandraSource(spec, queries));
+      queries = Lists.newArrayList();
     }
+
     return sourceList;
   }
 
@@ -242,13 +278,14 @@ private Cluster getCluster(List<String> hosts, int port, 
String username, String
       builder.withAuthProvider(new PlainTextAuthProvider(username, password));
     }
 
+    DCAwareRoundRobinPolicy.Builder dcAwarePolicyBuilder = new 
DCAwareRoundRobinPolicy.Builder();
     if (localDc != null) {
-      builder.withLoadBalancingPolicy(
-          new TokenAwarePolicy(new 
DCAwareRoundRobinPolicy.Builder().withLocalDc(localDc).build()));
-    } else {
-      builder.withLoadBalancingPolicy(new TokenAwarePolicy(new 
RoundRobinPolicy()));
+      dcAwarePolicyBuilder.withLocalDc(localDc);
     }
 
+    builder.withLoadBalancingPolicy(
+        new TokenAwarePolicy(dcAwarePolicyBuilder.build()));
+
     if (consistencyLevel != null) {
       builder.withQueryOptions(
           new 
QueryOptions().setConsistencyLevel(ConsistencyLevel.valueOf(consistencyLevel)));
@@ -277,8 +314,8 @@ private Cluster getCluster(List<String> hosts, int port, 
String username, String
             new TokenRange(
                 row.getLong("partitions_count"),
                 row.getLong("mean_partition_size"),
-                row.getLong("range_start"),
-                row.getLong("range_end"));
+                new BigInteger(row.getString("range_start")),
+                new BigInteger(row.getString("range_end")));
         tokenRanges.add(tokenRange);
       }
       // The table may not contain the estimates yet
@@ -302,7 +339,7 @@ protected static double getRingFraction(List<TokenRange> 
tokenRanges) {
     double ringFraction = 0;
     for (TokenRange tokenRange : tokenRanges) {
       ringFraction = ringFraction + (distance(tokenRange.rangeStart, 
tokenRange.rangeEnd)
-          .doubleValue() / TOTAL_TOKEN_COUNT.doubleValue());
+          .doubleValue() / 
SplitGenerator.getRangeSize(MURMUR3PARTITIONER).doubleValue());
     }
     return ringFraction;
   }
@@ -311,11 +348,11 @@ protected static double getRingFraction(List<TokenRange> 
tokenRanges) {
    * Measure distance between two tokens.
    */
   @VisibleForTesting
-  protected static BigInteger distance(long left, long right) {
-    if (right > left) {
-      return BigInteger.valueOf(right).subtract(BigInteger.valueOf(left));
+  protected static BigInteger distance(BigInteger left, BigInteger right) {
+    if (right.compareTo(left) > 0) {
+      return right.subtract(left);
     } else {
-      return 
BigInteger.valueOf(right).subtract(BigInteger.valueOf(left)).add(TOTAL_TOKEN_COUNT);
+      return 
right.subtract(left).add(SplitGenerator.getRangeSize(MURMUR3PARTITIONER));
     }
   }
 
@@ -324,7 +361,7 @@ protected static BigInteger distance(long left, long right) 
{
    */
   @VisibleForTesting
   protected static boolean isMurmur3Partitioner(Cluster cluster) {
-    return "org.apache.cassandra.dht.Murmur3Partitioner".equals(
+    return MURMUR3PARTITIONER.equals(
         cluster.getMetadata().getPartitioner());
   }
 
@@ -336,11 +373,11 @@ protected static boolean isMurmur3Partitioner(Cluster 
cluster) {
   protected static class TokenRange {
     private final long partitionCount;
     private final long meanPartitionSize;
-    private final long rangeStart;
-    private final long rangeEnd;
+    private final BigInteger rangeStart;
+    private final BigInteger rangeEnd;
 
     public TokenRange(
-        long partitionCount, long meanPartitionSize, long rangeStart, long
+        long partitionCount, long meanPartitionSize, BigInteger rangeStart, 
BigInteger
         rangeEnd) {
       this.partitionCount = partitionCount;
       this.meanPartitionSize = meanPartitionSize;
diff --git 
a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java
 
b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java
new file mode 100644
index 00000000000..377bdc98e9c
--- /dev/null
+++ 
b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/RingRange.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.cassandra;
+
+import java.math.BigInteger;
+
+/** Models a Cassandra token range. */
+final class RingRange {
+  private final BigInteger start;
+  private final BigInteger end;
+
+  RingRange(BigInteger start, BigInteger end) {
+    this.start = start;
+    this.end = end;
+  }
+
+  BigInteger getStart() {
+    return start;
+  }
+
+  BigInteger getEnd() {
+    return end;
+  }
+
+  /**
+   * Returns the size of this range.
+   *
+   * @return size of the range, max - range, in case of wrap
+   */
+  BigInteger span(BigInteger ringSize) {
+    if (start.compareTo(end) >= 0) {
+      return end.subtract(start).add(ringSize);
+    } else {
+      return end.subtract(start);
+    }
+  }
+
+  /**
+   * @return true if 0 is inside of this range. Note that if start == end, 
then wrapping is true
+   */
+  public boolean isWrapping() {
+    return start.compareTo(end) >= 0;
+  }
+
+  @Override
+  public String toString() {
+    return String.format("(%s,%s]", start.toString(), end.toString());
+  }
+}
diff --git 
a/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java
 
b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java
new file mode 100644
index 00000000000..5667edbb623
--- /dev/null
+++ 
b/sdks/java/io/cassandra/src/main/java/org/apache/beam/sdk/io/cassandra/SplitGenerator.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.cassandra;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.Lists;
+import java.math.BigInteger;
+import java.util.ArrayList;
+import java.util.List;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Splits given Cassandra table's token range into splits. */
+final class SplitGenerator {
+
+  private static final Logger LOG = 
LoggerFactory.getLogger(SplitGenerator.class);
+
+  private final String partitioner;
+  private final BigInteger rangeMin;
+  private final BigInteger rangeMax;
+  private final BigInteger rangeSize;
+
+  SplitGenerator(String partitioner) {
+    rangeMin = getRangeMin(partitioner);
+    rangeMax = getRangeMax(partitioner);
+    rangeSize = getRangeSize(partitioner);
+    this.partitioner = partitioner;
+  }
+
+  private static BigInteger getRangeMin(String partitioner) {
+    if (partitioner.endsWith("RandomPartitioner")) {
+      return BigInteger.ZERO;
+    } else if (partitioner.endsWith("Murmur3Partitioner")) {
+      return new BigInteger("2").pow(63).negate();
+    } else {
+      throw new UnsupportedOperationException(
+          "Unsupported partitioner. " + "Only Random and Murmur3 are 
supported");
+    }
+  }
+
+  private static BigInteger getRangeMax(String partitioner) {
+    if (partitioner.endsWith("RandomPartitioner")) {
+      return new BigInteger("2").pow(127).subtract(BigInteger.ONE);
+    } else if (partitioner.endsWith("Murmur3Partitioner")) {
+      return new BigInteger("2").pow(63).subtract(BigInteger.ONE);
+    } else {
+      throw new UnsupportedOperationException(
+          "Unsupported partitioner. " + "Only Random and Murmur3 are 
supported");
+    }
+  }
+
+  static BigInteger getRangeSize(String partitioner) {
+    return 
getRangeMax(partitioner).subtract(getRangeMin(partitioner)).add(BigInteger.ONE);
+  }
+
+  /**
+   * Given big0 properly ordered list of tokens, compute at least {@code 
totalSplitCount} splits.
+   * Each split can contain several token ranges in order to reduce the 
overhead of vnodes.
+   * Currently, token range grouping is not smart and doesn't check if they 
share the same
+   * replicas.
+   * This is planned to change once Beam is able to handle collocation with 
the Cassandra nodes.
+   *
+   * @param totalSplitCount requested total amount of splits. This function 
may generate more
+   *     splits.
+   * @param ringTokens list of all start tokens in big0 cluster. They have to 
be in ring order.
+   * @return big0 list containing at least {@code totalSplitCount} splits.
+   */
+  List<List<RingRange>> generateSplits(long totalSplitCount, List<BigInteger> 
ringTokens) {
+    int tokenRangeCount = ringTokens.size();
+
+    List<RingRange> splits = new ArrayList<>();
+    for (int i = 0; i < tokenRangeCount; i++) {
+      BigInteger start = ringTokens.get(i);
+      BigInteger stop = ringTokens.get((i + 1) % tokenRangeCount);
+
+      if (!inRange(start) || !inRange(stop)) {
+        throw new RuntimeException(
+            String.format("Tokens (%s,%s) not in range of %s", start, stop, 
partitioner));
+      }
+      if (start.equals(stop) && tokenRangeCount != 1) {
+        throw new RuntimeException(
+            String.format("Tokens (%s,%s): two nodes have the same token", 
start, stop));
+      }
+
+      BigInteger rs = stop.subtract(start);
+      if (rs.compareTo(BigInteger.ZERO) <= 0) {
+        // wrap around case
+        rs = rs.add(rangeSize);
+      }
+
+      // the below, in essence, does this:
+      // splitCount = ceiling((rangeSize / RANGE_SIZE) * totalSplitCount)
+      BigInteger[] splitCountAndRemainder =
+          
rs.multiply(BigInteger.valueOf(totalSplitCount)).divideAndRemainder(rangeSize);
+
+      int splitCount =
+          splitCountAndRemainder[0].intValue()
+              + (splitCountAndRemainder[1].equals(BigInteger.ZERO) ? 0 : 1);
+
+      LOG.info("Dividing token range [{},{}) into {} splits", start, stop, 
splitCount);
+
+      // Make big0 list of all the endpoints for the splits, including both 
start and stop
+      List<BigInteger> endpointTokens = new ArrayList<>();
+      for (int j = 0; j <= splitCount; j++) {
+        BigInteger offset =
+            
rs.multiply(BigInteger.valueOf(j)).divide(BigInteger.valueOf(splitCount));
+        BigInteger token = start.add(offset);
+        if (token.compareTo(rangeMax) > 0) {
+          token = token.subtract(rangeSize);
+        }
+        endpointTokens.add(token);
+      }
+
+      // Append the splits between the endpoints
+      for (int j = 0; j < splitCount; j++) {
+        splits.add(new RingRange(endpointTokens.get(j), endpointTokens.get(j + 
1)));
+        LOG.debug("Split #{}: [{},{})", j + 1, endpointTokens.get(j), 
endpointTokens.get(j + 1));
+      }
+    }
+
+    BigInteger total = BigInteger.ZERO;
+    for (RingRange split : splits) {
+      BigInteger size = split.span(rangeSize);
+      total = total.add(size);
+    }
+    if (!total.equals(rangeSize)) {
+      throw new RuntimeException(
+          "Some tokens are missing from the splits. " + "This should not 
happen.");
+    }
+    return coalesceSplits(getTargetSplitSize(totalSplitCount), splits);
+  }
+
+  private boolean inRange(BigInteger token) {
+    return !(token.compareTo(rangeMin) < 0 || token.compareTo(rangeMax) > 0);
+  }
+
+  @VisibleForTesting
+  List<List<RingRange>> coalesceSplits(
+      BigInteger targetSplitSize, List<RingRange> splits) {
+
+    List<List<RingRange>> coalescedSplits = Lists.newArrayList();
+    List<RingRange> tokenRangesForCurrentSplit = Lists.newArrayList();
+    BigInteger tokenCount = BigInteger.ZERO;
+
+      for (RingRange tokenRange : splits) {
+        if 
(tokenRange.span(rangeSize).add(tokenCount).compareTo(targetSplitSize) > 0
+            && !tokenRangesForCurrentSplit.isEmpty()) {
+          // enough tokens in that segment
+          LOG.info(
+              "Got enough tokens for one split ({}) : {}",
+              tokenCount,
+              tokenRangesForCurrentSplit);
+          coalescedSplits.add(tokenRangesForCurrentSplit);
+          tokenRangesForCurrentSplit = Lists.newArrayList();
+          tokenCount = BigInteger.ZERO;
+        }
+
+        tokenCount = tokenCount.add(tokenRange.span(rangeSize));
+        tokenRangesForCurrentSplit.add(tokenRange);
+      }
+
+      if (!tokenRangesForCurrentSplit.isEmpty()) {
+        coalescedSplits.add(tokenRangesForCurrentSplit);
+      }
+
+      return coalescedSplits;
+    }
+
+  private BigInteger getTargetSplitSize(long splitCount) {
+    return 
(rangeMax.subtract(rangeMin)).divide(BigInteger.valueOf(splitCount));
+  }
+}
diff --git 
a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOIT.java
 
b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOIT.java
index f323bbcafa4..8e517efa595 100644
--- 
a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOIT.java
+++ 
b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraIOIT.java
@@ -92,6 +92,7 @@ public void testRead() throws Exception {
     PCollection<Scientist> output = 
pipeline.apply(CassandraIO.<Scientist>read()
         .withHosts(Collections.singletonList(options.getCassandraHost()))
         .withPort(options.getCassandraPort())
+        .withMinNumberOfSplits(20)
         .withKeyspace(CassandraTestDataSet.KEYSPACE)
         .withTable(CassandraTestDataSet.TABLE_READ_NAME)
         .withEntity(Scientist.class)
diff --git 
a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraServiceImplTest.java
 
b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraServiceImplTest.java
index 1b27dc2e5dc..ee1afda2ba2 100644
--- 
a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraServiceImplTest.java
+++ 
b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/CassandraServiceImplTest.java
@@ -21,116 +21,90 @@
 import static org.junit.Assert.assertTrue;
 
 import com.datastax.driver.core.Cluster;
+import com.datastax.driver.core.ColumnMetadata;
+import com.datastax.driver.core.KeyspaceMetadata;
 import com.datastax.driver.core.Metadata;
+import com.datastax.driver.core.TableMetadata;
 import java.math.BigInteger;
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import org.junit.Test;
 import org.mockito.Mockito;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
-/**
- * Tests on {@link CassandraServiceImplTest}.
- */
+/** Tests on {@link CassandraServiceImplTest}. */
 public class CassandraServiceImplTest {
-
-  private static final Logger LOG = 
LoggerFactory.getLogger(CassandraServiceImplTest.class);
-
   private static final String MURMUR3_PARTITIONER = 
"org.apache.cassandra.dht.Murmur3Partitioner";
 
   private Cluster createClusterMock() {
     Metadata metadata = Mockito.mock(Metadata.class);
+    KeyspaceMetadata keyspaceMetadata = Mockito.mock(KeyspaceMetadata.class);
+    TableMetadata tableMetadata = Mockito.mock(TableMetadata.class);
+    ColumnMetadata columnMetadata = Mockito.mock(ColumnMetadata.class);
+
     Mockito.when(metadata.getPartitioner()).thenReturn(MURMUR3_PARTITIONER);
+    
Mockito.when(metadata.getKeyspace(Mockito.anyString())).thenReturn(keyspaceMetadata);
+    
Mockito.when(keyspaceMetadata.getTable(Mockito.anyString())).thenReturn(tableMetadata);
+    Mockito.when(tableMetadata.getPartitionKey())
+        .thenReturn(Collections.singletonList(columnMetadata));
+    Mockito.when(columnMetadata.getName()).thenReturn("$pk");
     Cluster cluster = Mockito.mock(Cluster.class);
     Mockito.when(cluster.getMetadata()).thenReturn(metadata);
     return cluster;
   }
 
   @Test
-  public void testValidPartitioner() throws Exception {
+  public void testValidPartitioner() {
     assertTrue(CassandraServiceImpl.isMurmur3Partitioner(createClusterMock()));
   }
 
   @Test
-  public void testDistance() throws Exception {
-    BigInteger distance = CassandraServiceImpl.distance(10L, 100L);
+  public void testDistance() {
+    BigInteger distance = CassandraServiceImpl.distance(new BigInteger("10"),
+        new BigInteger("100"));
     assertEquals(BigInteger.valueOf(90), distance);
 
-    distance = CassandraServiceImpl.distance(100L, 10L);
-    assertEquals(new BigInteger("18446744073709551525"), distance);
+    distance = CassandraServiceImpl.distance(new BigInteger("100"), new 
BigInteger("10"));
+    assertEquals(new BigInteger("18446744073709551526"), distance);
   }
 
   @Test
-  public void testRingFraction() throws Exception {
+  public void testRingFraction() {
     // simulate a first range taking "half" of the available tokens
     List<CassandraServiceImpl.TokenRange> tokenRanges = new ArrayList<>();
-    tokenRanges.add(new CassandraServiceImpl.TokenRange(1, 1, Long.MIN_VALUE, 
0));
+    tokenRanges.add(new CassandraServiceImpl.TokenRange(1, 1,
+        BigInteger.valueOf(Long.MIN_VALUE), new BigInteger("0")));
     assertEquals(0.5, CassandraServiceImpl.getRingFraction(tokenRanges), 0);
 
     // add a second range to cover all tokens available
-    tokenRanges.add(new CassandraServiceImpl.TokenRange(1, 1, 0, 
Long.MAX_VALUE));
+    tokenRanges.add(new CassandraServiceImpl.TokenRange(1, 1,
+        new BigInteger("0"), BigInteger.valueOf(Long.MAX_VALUE)));
     assertEquals(1.0, CassandraServiceImpl.getRingFraction(tokenRanges), 0);
   }
 
   @Test
-  public void testEstimatedSizeBytes() throws Exception {
+  public void testEstimatedSizeBytes() {
     List<CassandraServiceImpl.TokenRange> tokenRanges = new ArrayList<>();
     // one partition containing all tokens, the size is actually the size of 
the partition
-    tokenRanges.add(new CassandraServiceImpl.TokenRange(1, 1000, 
Long.MIN_VALUE, Long.MAX_VALUE));
+    tokenRanges.add(new CassandraServiceImpl.TokenRange(1, 1000,
+        BigInteger.valueOf(Long.MIN_VALUE), 
BigInteger.valueOf(Long.MAX_VALUE)));
     assertEquals(1000, 
CassandraServiceImpl.getEstimatedSizeBytes(tokenRanges));
 
     // one partition with half of the tokens, we estimate the size to the 
double of this partition
     tokenRanges = new ArrayList<>();
-    tokenRanges.add(new CassandraServiceImpl.TokenRange(1, 1000, 
Long.MIN_VALUE, 0));
+    tokenRanges.add(new CassandraServiceImpl.TokenRange(1, 1000,
+        BigInteger.valueOf(Long.MIN_VALUE), new BigInteger("0")));
     assertEquals(2000, 
CassandraServiceImpl.getEstimatedSizeBytes(tokenRanges));
 
     // we have three partitions covering all tokens, the size is the sum of 
partition size *
     // partition count
     tokenRanges = new ArrayList<>();
-    tokenRanges.add(new CassandraServiceImpl.TokenRange(1, 1000, 
Long.MIN_VALUE, -3));
-    tokenRanges.add(new CassandraServiceImpl.TokenRange(1, 1000, -2, 10000));
-    tokenRanges.add(new CassandraServiceImpl.TokenRange(2, 3000, 10001, 
Long.MAX_VALUE));
+    tokenRanges.add(new CassandraServiceImpl.TokenRange(1, 1000,
+        BigInteger.valueOf(Long.MIN_VALUE), new BigInteger("-3")));
+    tokenRanges.add(new CassandraServiceImpl.TokenRange(1, 1000,
+        new BigInteger("-2"), new BigInteger("10000")));
+    tokenRanges.add(new CassandraServiceImpl.TokenRange(2, 3000,
+        new BigInteger("10001"), BigInteger.valueOf(Long.MAX_VALUE)));
     assertEquals(8000, 
CassandraServiceImpl.getEstimatedSizeBytes(tokenRanges));
   }
-
-  @Test
-  public void testThreeSplits() throws Exception {
-    CassandraServiceImpl service = new CassandraServiceImpl();
-    CassandraIO.Read spec = 
CassandraIO.read().withKeyspace("beam").withTable("test");
-    List<CassandraIO.CassandraSource> sources = service.split(spec, 50, 150);
-    assertEquals(3, sources.size());
-    assertTrue(sources.get(0).splitQuery.matches("SELECT \\* FROM beam.test 
WHERE token\\"
-        + "(\\$pk\\)<(.*)"));
-    assertTrue(sources.get(1).splitQuery.matches("SELECT \\* FROM beam.test 
WHERE token\\"
-        + "(\\$pk\\)>=(.*) AND token\\(\\$pk\\)<(.*)"));
-    assertTrue(sources.get(2).splitQuery.matches("SELECT \\* FROM beam.test 
WHERE token\\"
-        + "(\\$pk\\)>=(.*)"));
-  }
-
-  @Test
-  public void testTwoSplits() throws Exception {
-    CassandraServiceImpl service = new CassandraServiceImpl();
-    CassandraIO.Read spec = 
CassandraIO.read().withKeyspace("beam").withTable("test");
-    List<CassandraIO.CassandraSource> sources = service.split(spec, 50, 100);
-    assertEquals(2, sources.size());
-    LOG.info("TOKEN: " + ((double) Long.MAX_VALUE / 2));
-    LOG.info(sources.get(0).splitQuery);
-    LOG.info(sources.get(1).splitQuery);
-    assertEquals("SELECT * FROM beam.test WHERE token($pk)<" + ((double) 
Long.MAX_VALUE / 2) + ";",
-        sources.get(0).splitQuery);
-    assertEquals("SELECT * FROM beam.test WHERE token($pk)>=" + ((double) 
Long.MAX_VALUE / 2)
-            + ";",
-        sources.get(1).splitQuery);
-  }
-
-  @Test
-  public void testUniqueSplit() throws Exception {
-    CassandraServiceImpl service = new CassandraServiceImpl();
-    CassandraIO.Read spec = 
CassandraIO.read().withKeyspace("beam").withTable("test");
-    List<CassandraIO.CassandraSource> sources = service.split(spec, 100, 100);
-    assertEquals(1, sources.size());
-    assertEquals("SELECT * FROM beam.test;", sources.get(0).splitQuery);
-  }
-
 }
diff --git 
a/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/SplitGeneratorTest.java
 
b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/SplitGeneratorTest.java
new file mode 100644
index 00000000000..bfc3f05415d
--- /dev/null
+++ 
b/sdks/java/io/cassandra/src/test/java/org/apache/beam/sdk/io/cassandra/SplitGeneratorTest.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.cassandra;
+
+import static org.junit.Assert.assertEquals;
+
+import java.math.BigInteger;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.junit.Test;
+
+/** Tests on {@link SplitGenerator}. */
+public final class SplitGeneratorTest {
+
+  @Test
+  public void testGenerateSegments() {
+    List<BigInteger> tokens =
+        Arrays.asList(
+                "0",
+                "1",
+                "56713727820156410577229101238628035242",
+                "56713727820156410577229101238628035243",
+                "113427455640312821154458202477256070484",
+                "113427455640312821154458202477256070485")
+            .stream()
+            .map(s -> new BigInteger(s))
+            .collect(Collectors.toList());
+
+    SplitGenerator generator = new SplitGenerator("foo.bar.RandomPartitioner");
+    List<List<RingRange>> segments = generator.generateSplits(10, tokens);
+
+    assertEquals(12, segments.size());
+    assertEquals("[(0,1], (1,14178431955039102644307275309657008811]]", 
segments.get(0).toString());
+    assertEquals(
+        
"[(14178431955039102644307275309657008811,28356863910078205288614550619314017621]]",
+        segments.get(1).toString());
+    assertEquals(
+        
"[(70892159775195513221536376548285044053,85070591730234615865843651857942052863]]",
+        segments.get(5).toString());
+
+    tokens =
+        Arrays.asList(
+                "5",
+                "6",
+                "56713727820156410577229101238628035242",
+                "56713727820156410577229101238628035243",
+                "113427455640312821154458202477256070484",
+                "113427455640312821154458202477256070485")
+            .stream()
+            .map(s -> new BigInteger(s))
+            .collect(Collectors.toList());
+
+    segments = generator.generateSplits(10, tokens);
+
+    assertEquals(12, segments.size());
+    assertEquals("[(5,6], (6,14178431955039102644307275309657008815]]", 
segments.get(0).toString());
+    assertEquals(
+        
"[(70892159775195513221536376548285044053,85070591730234615865843651857942052863]]",
+        segments.get(5).toString());
+    assertEquals(
+        
"[(141784319550391026443072753096570088109,155962751505430129087380028406227096921]]",
+        segments.get(10).toString());
+  }
+
+  @Test(expected = RuntimeException.class)
+  public void testZeroSizeRange() {
+    List<String> tokenStrings =
+        Arrays.asList(
+            "0",
+            "1",
+            "56713727820156410577229101238628035242",
+            "56713727820156410577229101238628035242",
+            "113427455640312821154458202477256070484",
+            "113427455640312821154458202477256070485");
+
+    List<BigInteger> tokens =
+        tokenStrings.stream().map(s -> new 
BigInteger(s)).collect(Collectors.toList());
+
+    SplitGenerator generator = new SplitGenerator("foo.bar.RandomPartitioner");
+    generator.generateSplits(10, tokens);
+  }
+
+  @Test
+  public void testRotatedRing() {
+    List<String> tokenStrings =
+        Arrays.asList(
+            "56713727820156410577229101238628035243",
+            "113427455640312821154458202477256070484",
+            "113427455640312821154458202477256070485",
+            "5",
+            "6",
+            "56713727820156410577229101238628035242");
+
+    List<BigInteger> tokens =
+        tokenStrings.stream().map(s -> new 
BigInteger(s)).collect(Collectors.toList());
+
+    SplitGenerator generator = new SplitGenerator("foo.bar.RandomPartitioner");
+    List<List<RingRange>> segments = generator.generateSplits(5, tokens);
+    assertEquals(6, segments.size());
+    assertEquals(
+        
"[(85070591730234615865843651857942052863,113427455640312821154458202477256070484],"
+            + " 
(113427455640312821154458202477256070484,113427455640312821154458202477256070485]]",
+        segments.get(1).toString());
+    assertEquals("[(113427455640312821154458202477256070485,"
+        + "141784319550391026443072753096570088109]]", 
segments.get(2).toString());
+    assertEquals(
+        "[(141784319550391026443072753096570088109,5], (5,6]]",
+        segments.get(3).toString());
+  }
+
+  @Test(expected = RuntimeException.class)
+  public void testDisorderedRing() {
+    List<String> tokenStrings =
+        Arrays.asList(
+            "0",
+            "113427455640312821154458202477256070485",
+            "1",
+            "56713727820156410577229101238628035242",
+            "56713727820156410577229101238628035243",
+            "113427455640312821154458202477256070484");
+
+    List<BigInteger> tokens =
+        tokenStrings.stream().map(s -> new 
BigInteger(s)).collect(Collectors.toList());
+
+    SplitGenerator generator = new SplitGenerator("foo.bar.RandomPartitioner");
+    generator.generateSplits(10, tokens);
+    // Will throw an exception when concluding that the repair segments don't 
add up.
+    // This is because the tokens were supplied out of order.
+  }
+}


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


Issue Time Tracking
-------------------

    Worklog Id:     (was: 96069)
    Time Spent: 5h 40m  (was: 5.5h)

> CassandraIO.read() splitting produces invalid queries
> -----------------------------------------------------
>
>                 Key: BEAM-3485
>                 URL: https://issues.apache.org/jira/browse/BEAM-3485
>             Project: Beam
>          Issue Type: Bug
>          Components: io-java-cassandra
>            Reporter: Eugene Kirpichov
>            Assignee: Alexander Dejanovski
>            Priority: Major
>          Time Spent: 5h 40m
>  Remaining Estimate: 0h
>
> See 
> [https://stackoverflow.com/questions/48090668/how-to-increase-dataflow-read-parallelism-from-cassandra/48131264?noredirect=1#comment83548442_48131264]
> As the question author points out, the error is likely that token($pk) should
> be token(pk). This was likely masked by BEAM-3424 and BEAM-3425: the splitting
> code path was effectively never invoked and has been broken since the first
> PR, so there are likely other bugs.
> When testing this issue, we must ensure good code coverage in an integration
> test (IT) against a real Cassandra instance.



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to