[
https://issues.apache.org/jira/browse/SPARK-2104?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=14045504#comment-14045504
]
Reynold Xin commented on SPARK-2104:
------------------------------------
BTW, I have some old code I wrote; you can base your changes on it:
{code}
/**
 * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
 * equal ranges. The ranges are determined by sampling the content of the RDD passed in.
 *
 * Note that the actual number of partitions created by the RangePartitioner might not be the
 * same as the `partitions` parameter: if `partitions` is 0, or if the sample drawn from the RDD
 * is empty, a single partition is used. If the sample contains fewer distinct keys than
 * `partitions`, some range bounds repeat and the corresponding partitions may end up empty.
 *
 * When the configured serializer (`SparkEnv.get.serializer`) is not a `JavaSerializer`, the
 * range bounds are written with that serializer, so keys do not need to be Java-serializable.
 *
 * @param partitions requested number of partitions; must be non-negative
 * @param rdd RDD whose keys are sampled to determine the range bounds
 * @param ascending if true, smaller keys go to lower-numbered partitions; if false, the
 *                  partition order is reversed
 * @tparam K key type; requires an `Ordering` and a `ClassTag`
 * @tparam V value type
 */
class RangePartitioner[K : Ordering : ClassTag, V](
    var partitions: Int,
    @transient rdd: RDD[_ <: Product2[K, V]],
    // A var (not a val) so that readObject can restore it in the non-Java serializer path.
    private var ascending: Boolean = true)
  extends Partitioner {

  // Fail fast: a negative count would otherwise only surface later as a confusing error.
  require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.")

  // Reassigned by readObject when a non-Java serializer is configured.
  private var ordering = implicitly[Ordering[K]]

  // An array of upper bounds for the first (partitions - 1) partitions.
  var rangeBounds: Array[K] = {
    if (partitions <= 1) {
      Array.empty[K]
    } else {
      val rddSize = rdd.count()
      // Aim for about 20 sampled keys per partition, capped at sampling the whole RDD.
      val maxSampleSize = partitions * 20.0
      val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
      // Sample without replacement using a fixed seed (1).
      val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sorted
      if (rddSample.isEmpty) {
        Array.empty[K]
      } else {
        // Take evenly spaced elements of the sorted sample as the upper bounds. The product is
        // computed in Long because (sampleLength - 1) * (i + 1) grows roughly with
        // 20 * partitions^2 and would overflow Int for large partition counts.
        Array.tabulate(partitions - 1) { i =>
          rddSample(((rddSample.length - 1).toLong * (i + 1) / partitions).toInt)
        }
      }
    }
  }

  @throws(classOf[IOException])
  private def writeObject(out: ObjectOutputStream): Unit = {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case _: JavaSerializer =>
        // Treat the Java serializer with the default action rather than going through
        // serialization, to avoid a separate serialization header.
        out.defaultWriteObject()
      case _ =>
        // readObject does not run the constructor, so everything getPartition/equals need must
        // come from the stream. The small, Java-serializable state is written directly to the
        // Java stream; the field order here must match readObject exactly.
        out.writeInt(partitions)
        out.writeBoolean(ascending)
        out.writeObject(ordering)
        out.writeObject(binarySearch)
        // The range bounds hold user keys, which may only be serializable by the
        // user-specified serializer.
        val ser = sfactory.newInstance()
        Utils.serializeViaNestedStream(out, ser) { stream =>
          // Write the tag of the array type, which is what readObject reads back.
          stream.writeObject(scala.reflect.classTag[Array[K]])
          stream.writeObject(rangeBounds)
        }
    }
  }

  @throws(classOf[IOException])
  private def readObject(in: ObjectInputStream): Unit = {
    val sfactory = SparkEnv.get.serializer
    sfactory match {
      case _: JavaSerializer => in.defaultReadObject()
      case _ =>
        // Must mirror the field order used by writeObject.
        partitions = in.readInt()
        ascending = in.readBoolean()
        ordering = in.readObject().asInstanceOf[Ordering[K]]
        binarySearch = in.readObject().asInstanceOf[(Array[K], K) => Int]
        val ser = sfactory.newInstance()
        Utils.deserializeViaNestedStream(in, ser) { ds =>
          val arrayTag = ds.readObject[ClassTag[Array[K]]]()
          rangeBounds = ds.readObject[Array[K]]()(arrayTag)
        }
    }
  }

  /** Number of partitions: one more than the number of range bounds. */
  def numPartitions: Int = rangeBounds.length + 1

  // Binary-search function over rangeBounds, created once rather than on every lookup.
  private var binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]

  /**
   * Returns the partition index for `key`: the first partition whose upper bound is >= key
   * (reversed when `ascending` is false).
   */
  def getPartition(key: Any): Int = {
    val k = key.asInstanceOf[K]
    var partition = 0
    if (rangeBounds.length < 1000) {
      // Fewer than 1000 bounds: a naive linear scan is cheap enough.
      while (partition < rangeBounds.length && ordering.gt(k, rangeBounds(partition))) {
        partition += 1
      }
    } else {
      // binarySearch either returns the match location or -[insertion point]-1.
      partition = binarySearch(rangeBounds, k)
      if (partition < 0) {
        partition = -partition - 1
      }
      if (partition > rangeBounds.length) {
        partition = rangeBounds.length
      }
    }
    if (ascending) {
      partition
    } else {
      rangeBounds.length - partition
    }
  }

  /** Two RangePartitioners are equal when their bounds and their `ascending` flags match. */
  override def equals(other: Any): Boolean = other match {
    case r: RangePartitioner[_, _] =>
      r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
    case _ =>
      false
  }

  /** Hash code consistent with equals: combines the bounds' hash codes and `ascending`. */
  override def hashCode(): Int = {
    val prime = 31
    var result = 1
    var i = 0
    while (i < rangeBounds.length) {
      result = prime * result + rangeBounds(i).hashCode
      i += 1
    }
    result = prime * result + ascending.hashCode
    result
  }
}
{code}
> RangePartitioner should use user specified serializer to serialize range
> bounds
> -------------------------------------------------------------------------------
>
> Key: SPARK-2104
> URL: https://issues.apache.org/jira/browse/SPARK-2104
> Project: Spark
> Issue Type: Bug
> Reporter: Reynold Xin
>
> Otherwise, it is pretty annoying to sort by keys whose types are not Java
> serializable.
> To reproduce, just set the serializer to Kryo, and run the following job:
> {code}
> // Comparable must be parameterized so that compareTo actually overrides it.
> class JavaNonSerializableClass extends Comparable[JavaNonSerializableClass] {
>   override def compareTo(o: JavaNonSerializableClass): Int = 0
> }
> sc.parallelize(Seq(new JavaNonSerializableClass, new JavaNonSerializableClass), 2)
>   .map(x => (x, x))
>   .sortByKey()
> {code}
> Basically, the partitioner will always be serialized using Java serialization (by the task
> closure serializer). However, the rangeBounds variable in RangePartitioner
> should be serialized with the user-specified serializer.
--
This message was sent by Atlassian JIRA
(v6.2#6252)