GitHub user KyleLi1985 commented on a diff in the pull request:
https://github.com/apache/spark/pull/22893#discussion_r231838390
--- Diff: mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala ---
@@ -521,19 +521,21 @@ object MLUtils extends Logging {
* The bound doesn't need the inner product, so we can use it as a sufficient condition to
* check quickly whether the inner product approach is accurate.
*/
- val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
- if (precisionBound1 < precision) {
- sqDist = sumSquaredNorm - 2.0 * dot(v1, v2)
- } else if (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) {
- val dotValue = dot(v1, v2)
- sqDist = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
- val precisionBound2 = EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) /
- (sqDist + EPSILON)
- if (precisionBound2 > precision) {
- sqDist = Vectors.sqdist(v1, v2)
- }
- } else {
+ if (v1.isInstanceOf[DenseVector] && v2.isInstanceOf[DenseVector]) {
sqDist = Vectors.sqdist(v1, v2)
+ } else {
+ val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
--- End diff --
@srowen Thanks for the review. I will push a new commit with the update and the related tests.
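The change reorders the branches so that a dense/dense pair goes straight to the exact squared distance, and only other pairs consult the norm-based precision bounds. The following is a minimal, dependency-free sketch of that control flow, not Spark's code. Several assumptions apply. Vectors are modelled as plain `Array[Double]`, and a `bothDense` flag stands in for the `isInstanceOf[DenseVector]` checks. The quoted diff is cut off right after the new `precisionBound1` line, so the rest of the new `else` branch is assumed here to keep the old two-bound logic. The `precision` default of `1e-6` is also an assumption. `FastSqDistSketch` and its method names are hypothetical.

```scala
// Hypothetical, simplified sketch of the branching shown in the diff.
// Vectors are plain arrays; `bothDense` replaces the DenseVector type checks.
object FastSqDistSketch {
  // Machine epsilon: halve until 1.0 + eps/2 is indistinguishable from 1.0.
  val EPSILON: Double = {
    var eps = 1.0
    while ((1.0 + (eps / 2.0)) != 1.0) eps /= 2.0
    eps
  }

  def dot(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(i => a(i) * b(i)).sum

  // Exact squared Euclidean distance, computed element by element.
  def sqdist(a: Array[Double], b: Array[Double]): Double =
    a.indices.map { i => val d = a(i) - b(i); d * d }.sum

  def fastSquaredDistance(
      v1: Array[Double], norm1: Double,
      v2: Array[Double], norm2: Double,
      bothDense: Boolean,
      precision: Double = 1e-6): Double = {
    require(v1.length == v2.length, "vectors must have the same length")
    val sumSquaredNorm = norm1 * norm1 + norm2 * norm2
    val normDiff = norm1 - norm2
    if (bothDense) {
      // New first branch from the diff: dense/dense always uses the exact loop.
      sqdist(v1, v2)
    } else {
      // Assumed continuation (cut off in the quoted diff): the old bound checks.
      val precisionBound1 =
        2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)
      if (precisionBound1 < precision) {
        // ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b is accurate enough here.
        sumSquaredNorm - 2.0 * dot(v1, v2)
      } else {
        val dotValue = dot(v1, v2)
        val approx = math.max(sumSquaredNorm - 2.0 * dotValue, 0.0)
        val precisionBound2 =
          EPSILON * (sumSquaredNorm + 2.0 * math.abs(dotValue)) / (approx + EPSILON)
        // Fall back to the exact computation when the shortcut may be imprecise.
        if (precisionBound2 > precision) sqdist(v1, v2) else approx
      }
    }
  }
}
```

For `a = (1, 2, 3)` and `b = (4, 6, 3)`, both paths give a squared distance of 25. The dense path is exact; the shortcut path agrees up to floating-point rounding. Comparing a vector with itself makes `normDiff` zero, so the first bound fails and the sketch falls back to the exact loop.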