srowen commented on a change in pull request #30548:
URL: https://github.com/apache/spark/pull/30548#discussion_r532684049
##########
File path: mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
##########
@@ -278,34 +279,45 @@ class Word2VecModel private[ml] (
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
+ private var bcModel: Broadcast[Word2VecModel] = _
Review comment:
I don't suppose we have a way to clean this up after use — it will just
have to be GCed?
##########
File path: mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
##########
@@ -538,9 +538,13 @@ class Word2VecModel private[spark] (
@Since("1.1.0")
def transform(word: String): Vector = {
wordIndex.get(word) match {
- case Some(ind) =>
- val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize +
vectorSize)
- Vectors.dense(vec.map(_.toDouble))
+ case Some(index) =>
+ val size = vectorSize
+ val offset = index * size
+ val array = Array.ofDim[Double](size)
+ var i = 0
+ while (i < size) { array(i) = wordVectors(offset + i); i += 1 }
Review comment:
Is this actually more efficient than slice? Likewise above.
##########
File path: mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
##########
@@ -502,19 +502,19 @@ class Word2VecModel private[spark] (
private val vectorSize = wordVectors.length / numWords
// wordList: Ordered list of words obtained from wordIndex.
- private val wordList: Array[String] = {
- val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
- wl.toArray
+ private lazy val wordList: Array[String] = {
+ wordIndex.toSeq.sortBy(_._2).iterator.map(_._1).toArray
}
// wordVecNorms: Array of length numWords, each value being the Euclidean
norm
// of the wordVector.
- private val wordVecNorms: Array[Float] = {
- val wordVecNorms = new Array[Float](numWords)
+ private lazy val wordVecNorms: Array[Float] = {
Review comment:
How much does this save, if it only happens once and has to happen anyway
to use the model?
##########
File path: mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
##########
@@ -278,34 +279,45 @@ class Word2VecModel private[ml] (
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
+ private var bcModel: Broadcast[Word2VecModel] = _
+
/**
* Transform a sentence column to a vector column to represent the whole
sentence. The transform
* is performed by averaging all word vectors it contains.
*/
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val outputSchema = transformSchema(dataset.schema, logging = true)
- val vectors = wordVectors.getVectors
- .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
- .map(identity).toMap // mapValues doesn't return a serializable map
(SI-7005)
- val bVectors = dataset.sparkSession.sparkContext.broadcast(vectors)
- val d = $(vectorSize)
- val emptyVec = Vectors.sparse(d, Array.emptyIntArray,
Array.emptyDoubleArray)
- val word2Vec = udf { sentence: Seq[String] =>
+
+ if (bcModel == null) {
+ bcModel = dataset.sparkSession.sparkContext.broadcast(this)
Review comment:
Looks like you only use this.wordVectors below? Maybe just broadcast that.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]