Updated Branches:
  refs/heads/master ba38d9892 -> b0dab1bb9

Merge pull request #571 from holdenk/switchtobinarysearch.

SPARK-1072 Use binary search when needed in RangePartitioner

Author: Holden Karau <hol...@pigscanfly.ca>

Closes #571 and squashes the following commits:

f31a2e1 [Holden Karau] Switch to using CollectionsUtils in Partitioner
4c7a0c3 [Holden Karau] Add CollectionsUtil as suggested by aarondav
7099962 [Holden Karau] Initialize the binary search function only once
1bef01d [Holden Karau] CR feedback
a21e097 [Holden Karau] Use binary search if we have more than 1000 elements inside of RangePartitioner


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/b0dab1bb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/b0dab1bb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/b0dab1bb

Branch: refs/heads/master
Commit: b0dab1bb9f4cfacae68b106a44d9b14f6bea3d29
Parents: ba38d98
Author: Holden Karau <hol...@pigscanfly.ca>
Authored: Tue Feb 11 14:48:59 2014 -0800
Committer: Reynold Xin <r...@apache.org>
Committed: Tue Feb 11 14:48:59 2014 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/Partitioner.scala    | 21 +++++++--
 .../org/apache/spark/util/CollectionsUtil.scala | 46 ++++++++++++++++++++
 .../org/apache/spark/PartitioningSuite.scala    | 29 +++++++++++-
 3 files changed, 91 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/b0dab1bb/core/src/main/scala/org/apache/spark/Partitioner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index cfba43d..ad99882 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -20,6 +20,7 @@ package org.apache.spark
 import scala.reflect.ClassTag
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.CollectionsUtils
 import org.apache.spark.util.Utils
 
 /**
@@ -118,12 +119,26 @@ class RangePartitioner[K <% Ordered[K]: ClassTag, V](
 
   def numPartitions = partitions
 
+  private val binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
+
   def getPartition(key: Any): Int = {
-    // TODO: Use a binary search here if number of partitions is large
     val k = key.asInstanceOf[K]
     var partition = 0
-    while (partition < rangeBounds.length && k > rangeBounds(partition)) {
-      partition += 1
+    if (rangeBounds.length < 1000) {
+      // If we have fewer than 1000 range bounds, use a naive linear search
+      while (partition < rangeBounds.length && k > rangeBounds(partition)) {
+        partition += 1
+      }
+    } else {
+      // Use the binary search function (chosen once at construction time).
+      partition = binarySearch(rangeBounds, k)
+      // binarySearch either returns the match location or -[insertion point]-1
+      if (partition < 0) {
+        partition = -partition - 1
+      }
+      if (partition > rangeBounds.length) {
+        partition = rangeBounds.length
+      }
     }
     if (ascending) {
       partition
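
For reference, here is a minimal standalone sketch (plain Scala, hypothetical bounds and names; not part of the patch) of the insertion-point convention that getPartition relies on: java.util.Arrays.binarySearch returns the index of a match, or -(insertion point) - 1 when the key is absent, so negating and subtracting one recovers the partition a missing key belongs to.

import java.util.Arrays

object InsertionPointDemo {
  def main(args: Array[String]): Unit = {
    // Hypothetical range bounds: partition i holds keys <= bounds(i).
    val bounds = Array(10, 20, 30, 40)

    def partitionFor(key: Int): Int = {
      val pos = Arrays.binarySearch(bounds, key)
      // Found: pos is the matching bound's index.
      // Not found: pos = -(insertionPoint) - 1, so invert it.
      if (pos < 0) -pos - 1 else pos
    }

    println(partitionFor(20)) // 1: exact match on bounds(1)
    println(partitionFor(25)) // 2: falls between bounds(1) and bounds(2)
    println(partitionFor(99)) // 4: past the last bound, i.e. the final partition
  }
}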

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/b0dab1bb/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala b/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala
new file mode 100644
index 0000000..db3db87
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util
+
+import scala.Array
+import scala.reflect._
+
+object CollectionsUtils {
+  def makeBinarySearch[K <% Ordered[K] : ClassTag] : (Array[K], K) => Int = {
+    classTag[K] match {
+      case ClassTag.Float =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Float]], x.asInstanceOf[Float])
+      case ClassTag.Double =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Double]], x.asInstanceOf[Double])
+      case ClassTag.Byte =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Byte]], x.asInstanceOf[Byte])
+      case ClassTag.Char =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Char]], x.asInstanceOf[Char])
+      case ClassTag.Short =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Short]], x.asInstanceOf[Short])
+      case ClassTag.Int =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Int]], x.asInstanceOf[Int])
+      case ClassTag.Long =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[Long]], x.asInstanceOf[Long])
+      case _ =>
+        (l, x) => util.Arrays.binarySearch(l.asInstanceOf[Array[AnyRef]], x)
+    }
+  }
+}
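
A quick usage sketch of the factory above (hypothetical data and object name): the ClassTag match runs once per call to makeBinarySearch, so the returned function can be cached in a val and reused per key, and the primitive cases dispatch to the matching java.util.Arrays.binarySearch overload without boxing.

import org.apache.spark.util.CollectionsUtils

object MakeBinarySearchDemo {
  def main(args: Array[String]): Unit = {
    // ClassTag.Int case: delegates to Arrays.binarySearch(int[], int).
    val searchInt = CollectionsUtils.makeBinarySearch[Int]
    println(searchInt(Array(1, 3, 5, 7), 5)) // 2: found at index 2
    println(searchInt(Array(1, 3, 5, 7), 4)) // -3: -(insertion point 2) - 1

    // Fallback case: reference types go through the Array[AnyRef] overload.
    val searchString = CollectionsUtils.makeBinarySearch[String]
    println(searchString(Array("a", "c", "e"), "c")) // 1: found at index 1
  }
}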

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/b0dab1bb/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 1374d01..1c5d5ea 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark
 import scala.math.abs
 import scala.collection.mutable.ArrayBuffer
 
-import org.scalatest.FunSuite
+import org.scalatest.{FunSuite, PrivateMethodTester}
 
 import org.apache.spark.SparkContext._
 import org.apache.spark.util.StatCounter
 import org.apache.spark.rdd.RDD
 
-class PartitioningSuite extends FunSuite with SharedSparkContext {
+class PartitioningSuite extends FunSuite with SharedSparkContext with PrivateMethodTester {
 
   test("HashPartitioner equality") {
     val p2 = new HashPartitioner(2)
@@ -67,6 +67,31 @@ class PartitioningSuite extends FunSuite with SharedSparkContext {
     assert(descendingP4 != p4)
   }
 
+  test("RangePartitioner getPartition") {
+    val rdd = sc.parallelize(1.to(2000)).map(x => (x, x))
+    // getPartition uses a linear scan when there are fewer than 1000 range bounds
+    // and a binary search otherwise, so cover partition counts in both regimes.
+    val partitionSizes = List(1, 2, 10, 100, 500, 1000, 1500)
+    val partitioners = partitionSizes.map(p => (p, new RangePartitioner(p, rdd)))
+    val decoratedRangeBounds = PrivateMethod[Array[Int]]('rangeBounds)
+    partitioners.foreach { case (numPartitions, partitioner) =>
+      val rangeBounds = partitioner.invokePrivate(decoratedRangeBounds())
+      1.to(1000).foreach { element =>
+        val partition = partitioner.getPartition(element)
+        if (numPartitions > 1) {
+          if (partition < rangeBounds.size) {
+            assert(element <= rangeBounds(partition))
+          }
+          if (partition > 0) {
+            assert(element > rangeBounds(partition - 1))
+          }
+        } else {
+          assert(partition === 0)
+        }
+      }
+    }
+  }
+
   test("HashPartitioner not equal to RangePartitioner") {
     val rdd = sc.parallelize(1 to 10).map(x => (x, x))
     val rangeP2 = new RangePartitioner(2, rdd)
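
For anyone new to the PrivateMethodTester pattern used in the test above: ScalaTest reaches private members reflectively through a PrivateMethod handle plus invokePrivate. A minimal sketch against a hypothetical class (not Spark code):

import org.scalatest.PrivateMethodTester

object PrivateAccessDemo extends PrivateMethodTester {
  class Holder { private val bounds: Array[Int] = Array(10, 20, 30) }

  def main(args: Array[String]): Unit = {
    // The handle is keyed by the private member's name and result type.
    val boundsAccessor = PrivateMethod[Array[Int]]('bounds)
    // invokePrivate reflectively calls the private accessor on the target.
    println(new Holder().invokePrivate(boundsAccessor()).mkString(", ")) // 10, 20, 30
  }
}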
