zhengruifeng commented on a change in pull request #26550:
[WIP][SPARK-29334][ML][MLLIB] Support for basic vector operators
URL: https://github.com/apache/spark/pull/26550#discussion_r353057261
##########
File path: mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala
##########
@@ -753,6 +799,24 @@ class SparseVector @Since("2.0.0") (
}.unzip
new SparseVector(selectedIndices.length, sliceInds.toArray,
sliceVals.toArray)
}
+
+ def +(v: Vector): Vector = {
+ require(this.size == v.size)
+ v match {
+ case dv: DenseVector =>
+ val arr = dv.values.clone
+ for (i <- this.indices) arr(i) += this(i)
+ Vectors.dense(arr)
+ case sv: SparseVector =>
Review comment:
The current implementation (1) needs a distinct and a sort, and (2) has complexity O(nnz * log(nnz)).
It may be faster to add two `SparseVector`s in this way:
```scala
/**
 * Adds another sparse vector to this one element-wise.
 *
 * When either operand has no stored indices, the other operand's `copy` is
 * returned as-is. Otherwise both index arrays are walked once with two
 * cursors, and only entries whose resulting value is not equal to 0.0 are
 * written to the output.
 *
 * NOTE(review): the merge presumably relies on both `indices` arrays being
 * sorted in strictly increasing order -- confirm against the class invariant.
 *
 * @param v the vector to add; must have the same `size` as this vector
 * @return a sparse vector of the same `size` holding the element-wise sum
 * @throws IllegalArgumentException (via `require`) if the sizes differ
 */
def add(v: SparseVector): SparseVector = {
  require(size == v.size)
  if (v.indices.isEmpty) {
    copy
  } else if (indices.isEmpty) {
    v.copy
  } else {
    val lhsIdx = indices
    val lhsVal = values
    val rhsIdx = v.indices
    val rhsVal = v.values
    val outIdx = mutable.ArrayBuilder.make[Int]
    val outVal = mutable.ArrayBuilder.make[Double]

    // Appends one (index, value) pair unless the value compares equal to 0.0.
    def emit(index: Int, value: Double): Unit = {
      if (value != 0.0) {
        outIdx += index
        outVal += value
      }
    }

    var i = 0
    var j = 0
    // Merge step: advance whichever cursor holds the smaller index,
    // or both when the indices coincide (values are summed in that case).
    while (i < lhsIdx.length && j < rhsIdx.length) {
      val a = lhsIdx(i)
      val b = rhsIdx(j)
      if (a == b) {
        emit(a, lhsVal(i) + rhsVal(j))
        i += 1
        j += 1
      } else if (a < b) {
        emit(a, lhsVal(i))
        i += 1
      } else {
        emit(b, rhsVal(j))
        j += 1
      }
    }
    // At most one of the two operands still has entries left; drain it.
    while (i < lhsIdx.length) {
      emit(lhsIdx(i), lhsVal(i))
      i += 1
    }
    while (j < rhsIdx.length) {
      emit(rhsIdx(j), rhsVal(j))
      j += 1
    }
    new SparseVector(size, outIdx.result(), outVal.result())
  }
}
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]