zhengruifeng commented on issue #28206: [SPARK-31436][ML] MinHash keyDistance optimization URL: https://github.com/apache/spark/pull/28206#issuecomment-612868575 testcode: ``` import scala.util.Random import org.apache.spark.ml.linalg._ val rng = new Random(42) val vec1 = Vectors.dense(Array.fill(10000)(1.0).map(_.toDouble)) val vec2 = Vectors.sparse(10000, rng.shuffle(Seq.range(0, 100)).take(10000).toArray.sorted, Array.fill(100)(1.0)) val vec3 = Vectors.sparse(10000, rng.shuffle(Seq.range(0, 100)).take(10000).toArray.sorted, Array.fill(100)(1.0)) def getNonZeroIterator(vec: Vector): Iterator[(Int, Double)] = { vec match { case DenseVector(values) => Iterator.tabulate(values.length)(i => (i, values(i))).filter(_._2 != 0) case SparseVector(_, indices, values) => Iterator.tabulate(indices.length)(i => (indices(i), values(i))).filter(_._2 != 0) } } def keyDistance1(x: Vector, y: Vector): Double = { val xSet = getNonZeroIterator(x).map(_._1).toSet val ySet = getNonZeroIterator(y).map(_._1).toSet val intersectionSize = xSet.intersect(ySet).size.toDouble val unionSize = xSet.size + ySet.size - intersectionSize assert(unionSize > 0, "The union of two input sets must have at least 1 elements") 1 - intersectionSize / unionSize } def keyDistance2(x: Vector, y: Vector): Double = { val xIter = getNonZeroIterator(x).map(_._1) val yIter = getNonZeroIterator(y).map(_._1) if (xIter.isEmpty) { assert(yIter.hasNext, "The union of two input sets must have at least 1 elements") return 0.0 } else if (yIter.isEmpty) { return 0.0 } var xIndex = xIter.next var yIndex = yIter.next var xSize = 1 var ySize = 1 var intersectionSize = 0 while (xIndex != -1 || yIndex != -1) { if (xIndex != -1 && yIndex != -1) { if (xIndex == yIndex) { intersectionSize += 1 xIndex = if (xIter.hasNext) { xSize += 1; xIter.next } else -1 yIndex = if (yIter.hasNext) { ySize += 1; yIter.next } else -1 } else if (xIndex > yIndex) { yIndex = if (yIter.hasNext) { ySize += 1; yIter.next } else -1 } else { xIndex = if 
(xIter.hasNext) { xSize += 1; xIter.next } else -1 } } else if (xIndex != -1) { while (xIter.hasNext) { xIndex = xIter.next; xSize += 1 } xIndex = -1 } else { while (yIter.hasNext) { yIndex = yIter.next; ySize += 1 } yIndex = -1 } } val unionSize = xSize + ySize - intersectionSize assert(unionSize > 0, "The union of two input sets must have at least 1 elements") 1 - intersectionSize.toDouble / unionSize } ``` results: ``` scala> val start = System.currentTimeMillis; Seq.range(0, 10000).foreach { i => keyDistance1(vec1, vec1) }; val end = System.currentTimeMillis; val duration = end - start; start: Long = 1586778279745 end: Long = 1586778324648 duration: Long = 44903 scala> val start = System.currentTimeMillis; Seq.range(0, 10000).foreach { i => keyDistance2(vec1, vec1) }; val end = System.currentTimeMillis; val duration = end - start; start: Long = 1586778402039 end: Long = 1586778406977 duration: Long = 4938 scala> val start = System.currentTimeMillis; Seq.range(0, 10000).foreach { i => keyDistance1(vec1, vec2) }; val end = System.currentTimeMillis; val duration = end - start; start: Long = 1586778414223 end: Long = 1586778432697 duration: Long = 18474 scala> val start = System.currentTimeMillis; Seq.range(0, 10000).foreach { i => keyDistance2(vec1, vec2) }; val end = System.currentTimeMillis; val duration = end - start; start: Long = 1586778439978 end: Long = 1586778442346 duration: Long = 2368 scala> val start = System.currentTimeMillis; Seq.range(0, 10000).foreach { i => keyDistance1(vec2, vec3) }; val end = System.currentTimeMillis; val duration = end - start; start: Long = 1586778451556 end: Long = 1586778451851 duration: Long = 295 scala> val start = System.currentTimeMillis; Seq.range(0, 10000).foreach { i => keyDistance2(vec2, vec3) }; val end = System.currentTimeMillis; val duration = end - start; start: Long = 1586778458768 end: Long = 1586778458821 duration: Long = 53 ``` if both vectors are dense, new impl is 9.09x faster; if both vectors are sparse, 
new impl is 5.57x faster; if one is dense and the other is sparse, new impl is 7.80x faster.
---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: [email protected] With regards, Apache Git Services --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
