GitHub user hhbyyh commented on a diff in the pull request:
https://github.com/apache/spark/pull/21942#discussion_r207103066
--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala ---
@@ -160,15 +160,89 @@ class StandardScalerModel private[ml] (
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
-
-    // TODO: Make the transformer natively in ml framework to avoid extra conversion.
-    val transformer: Vector => Vector = v => scaler.transform(OldVectors.fromML(v)).asML
+    val transformer: Vector => Vector = v => transform(v)
     val scale = udf(transformer)
     dataset.withColumn($(outputCol), scale(col($(inputCol))))
   }
+
+  /**
+   * Since `shift` will only be used in the `withMean` branch, we make it a
+   * `lazy val` so it is evaluated only in that branch. Note that we don't
+   * want to create this array multiple times in the `transform` function.
+   */
+  private lazy val shift: Array[Double] = mean.toArray
+
+  /**
+   * Applies standardization transformation on a vector.
+   *
+   * @param vector Vector to be standardized.
+   * @return Standardized vector. If the std of a column is zero, it will return the
+   *         default `0.0` for that column.
+   */
+  @Since("2.3.0")
+  def transform(vector: Vector): Vector = {
--- End diff --
Should this be `private[spark]`?
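
For context, the diff is truncated at the new method's signature, so its body is not shown above. As a rough sketch only: per-column standardization of this kind typically subtracts the mean (when `withMean` is set) and divides by the standard deviation (when `withStd` is set), with zero-std columns mapping to `0.0` as the `@return` doc notes. The standalone helper below, including its name and explicit parameters, is illustrative and not the PR's actual code:

```scala
import org.apache.spark.ml.linalg.{Vector, Vectors}

// Hypothetical sketch, not StandardScalerModel's implementation.
def standardize(
    v: Vector,
    mean: Array[Double],  // per-column means
    std: Array[Double],   // per-column standard deviations
    withMean: Boolean,
    withStd: Boolean): Vector = {
  // Clone, since DenseVector.toArray can return the backing array itself.
  val values = v.toArray.clone()
  var i = 0
  while (i < values.length) {
    if (withMean) values(i) -= mean(i)
    if (withStd) {
      // Zero-std columns yield 0.0 instead of NaN or Infinity.
      values(i) = if (std(i) != 0.0) values(i) / std(i) else 0.0
    }
    i += 1
  }
  Vectors.dense(values)
}
```

On the visibility question itself: `private[spark]` would make the method callable only from code under the `org.apache.spark` package, keeping the vector-level API off the public surface while the `Dataset`-based `transform` stays public.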