Github user holdenk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21942#discussion_r216022250
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala ---
    @@ -160,15 +160,88 @@ class StandardScalerModel private[ml] (
       @Since("2.0.0")
       override def transform(dataset: Dataset[_]): DataFrame = {
         transformSchema(dataset.schema, logging = true)
    -    val scaler = new feature.StandardScalerModel(std, mean, $(withStd), 
$(withMean))
    -
    -    // TODO: Make the transformer natively in ml framework to avoid extra 
conversion.
    -    val transformer: Vector => Vector = v => 
scaler.transform(OldVectors.fromML(v)).asML
    +    val transformer: Vector => Vector = v => transform(v)
     
         val scale = udf(transformer)
         dataset.withColumn($(outputCol), scale(col($(inputCol))))
       }
     
    +  /**
    +   * Since `shift` will be only used in `withMean` branch, we have it as
    +   * `lazy val` so it will be evaluated in that branch. Note that we don't
    +   * want to create this array multiple times in `transform` function.
    +   */
    +  private lazy val shift: Array[Double] = mean.toArray
    +
    +   /**
    +    * Applies standardization transformation on a vector.
    +    *
    +    * @param vector Vector to be standardized.
    +    * @return Standardized vector. If the std of a column is zero, it will 
return default `0.0`
    +    *         for the column with zero std.
    +    */
    +  private[spark] def transform(vector: Vector): Vector = {
    +    require(mean.size == vector.size)
    +    if ($(withMean)) {
    +      /**
    +       * By default, Scala generates Java methods for member variables. So 
every time
    +       * member variables are accessed, `invokespecial` is called. This is 
an expensive
    +       * operation, and can be avoided by having a local reference of 
`shift`.
    +       */
    +      val localShift = shift
    +      /**  Must have a copy of the values since they will be modified in 
place. */
    +      val values = vector match {
    +        /** Handle DenseVector specially because its `toArray` method does 
not clone values. */
    +        case d: DenseVector => d.values.clone()
    +        case v: Vector => v.toArray
    +      }
    +      val size = values.length
    +      if ($(withStd)) {
    +        var i = 0
    +        while (i < size) {
    +          values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * 
(1.0 / std(i)) else 0.0
    +          i += 1
    +        }
    +      } else {
    +        var i = 0
    +        while (i < size) {
    +          values(i) -= localShift(i)
    +          i += 1
    +        }
    +      }
    +      Vectors.dense(values)
    +    } else if ($(withStd)) {
    --- End diff --
    
    Maybe leave a comment here noting that this branch handles the case of withStd
    without withMean, since the nested if/else-if flow can get a bit confusing when
    tracing the code by hand.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to