Repository: spark
Updated Branches:
  refs/heads/master 8f11c6116 -> 2728c3df6


[SPARK-7578] [ML] [DOC] User guide for spark.ml Normalizer, IDF, StandardScaler

Added user guide sections with code examples.
Also added small Java unit tests to test Java example in guide.

CC: mengxr

Author: Joseph K. Bradley <jos...@databricks.com>

Closes #6127 from jkbradley/feature-guide-2 and squashes the following commits:

cd47f4b [Joseph K. Bradley] Updated based on code review
f16bcec [Joseph K. Bradley] Fixed merge issues and update Python examples print calls for Python 3
0a862f9 [Joseph K. Bradley] Added Normalizer, StandardScaler to ml-features doc, plus small Java unit tests
a21c2d6 [Joseph K. Bradley] Updated ml-features.md with IDF


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2728c3df
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2728c3df
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2728c3df

Branch: refs/heads/master
Commit: 2728c3df6690c2fcd4af3bd1c604c98ef6d509a5
Parents: 8f11c61
Author: Joseph K. Bradley <jos...@databricks.com>
Authored: Thu May 21 22:59:45 2015 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu May 21 22:59:45 2015 -0700

----------------------------------------------------------------------
 docs/ml-features.md                             | 224 ++++++++++++++++---
 .../spark/ml/feature/JavaHashingTFSuite.java    |  17 +-
 .../spark/ml/feature/JavaNormalizerSuite.java   |  71 ++++++
 .../ml/feature/JavaStandardScalerSuite.java     |  71 ++++++
 4 files changed, 351 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2728c3df/docs/ml-features.md
----------------------------------------------------------------------
diff --git a/docs/ml-features.md b/docs/ml-features.md
index 06f1ac1..efe9b3b 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -18,30 +18,38 @@ This section covers algorithms for working with features, roughly divided into t
 
 # Feature Extractors
 
-## Hashing Term-Frequency (HashingTF)
+## TF-IDF (HashingTF and IDF)
 
-`HashingTF` is a `Transformer` which takes sets of terms (e.g., `String` terms can be sets of words) and converts those sets into fixed-length feature vectors.
-The algorithm combines [Term Frequency (TF)](http://en.wikipedia.org/wiki/Tf%E2%80%93idf) counts with the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing) for dimensionality reduction.  Please refer to the [MLlib user guide on TF-IDF](mllib-feature-extraction.html#tf-idf) for more details on Term-Frequency.
+[Term Frequency-Inverse Document Frequency (TF-IDF)](http://en.wikipedia.org/wiki/Tf%E2%80%93idf) is a common text pre-processing step.  In Spark ML, TF-IDF is separated into two parts: TF (+hashing) and IDF.
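+
+As a quick reference (the exact definitions are in the MLlib TF-IDF guide referenced below), TF-IDF is the product of the two statistics:
+
+`\[
+TFIDF(t, d, D) = TF(t, d) \cdot IDF(t, D)
+\]`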
 
-HashingTF is implemented in
-[HashingTF](api/scala/index.html#org.apache.spark.ml.feature.HashingTF).
-In the following code segment, we start with a set of sentences.  We split each sentence into words using `Tokenizer`.  For each sentence (bag of words), we hash it into a feature vector.  This feature vector could then be passed to a learning algorithm.
+**TF**: `HashingTF` is a `Transformer` which takes sets of terms and converts those sets into fixed-length feature vectors.  In text processing, a "set of terms" might be a bag of words.
+The algorithm combines Term Frequency (TF) counts with the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing) for dimensionality reduction.
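+
+As a minimal sketch of the hashing trick (a hypothetical helper, not the actual MLlib implementation), each term is hashed to a column index modulo the vector size, and colliding terms have their counts summed:
+
+{% highlight scala %}
+// Hypothetical illustration of hashed term frequencies.
+def hashedTermFrequencies(terms: Seq[String], numFeatures: Int): Map[Int, Double] = {
+  // Map each term to a non-negative index in [0, numFeatures).
+  val indices = terms.map(t => ((t.hashCode % numFeatures) + numFeatures) % numFeatures)
+  // Count how many terms land on each index; hash collisions add up.
+  indices.groupBy(identity).mapValues(_.size.toDouble)
+}
+{% endhighlight %}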
+
+**IDF**: `IDF` is an `Estimator` which is fit on a dataset and produces an `IDFModel`.  The `IDFModel` takes feature vectors (generally created from `HashingTF`) and scales each column.  Intuitively, it down-weights columns which appear frequently in a corpus.
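+
+For intuition, the IDF weighting described in the MLlib TF-IDF guide is log-scaled:
+
+`\[
+IDF(t, D) = \log \frac{|D| + 1}{DF(t, D) + 1}
+\]`
+
+where $|D|$ is the number of documents in the corpus and $DF(t, D)$ is the number of documents containing term $t$; the added 1s avoid division by zero for unseen terms.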
+
+Please refer to the [MLlib user guide on TF-IDF](mllib-feature-extraction.html#tf-idf) for more details on Term Frequency and Inverse Document Frequency.
+For API details, refer to the [HashingTF API docs](api/scala/index.html#org.apache.spark.ml.feature.HashingTF) and the [IDF API docs](api/scala/index.html#org.apache.spark.ml.feature.IDF).
+
+In the following code segment, we start with a set of sentences.  We split each sentence into words using `Tokenizer`.  For each sentence (bag of words), we use `HashingTF` to hash the sentence into a feature vector.  We use `IDF` to rescale the feature vectors; this generally improves performance when using text as features.  Our feature vectors could then be passed to a learning algorithm.
 
 <div class="codetabs">
 <div data-lang="scala" markdown="1">
 {% highlight scala %}
-import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
+import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
 
-val sentenceDataFrame = sqlContext.createDataFrame(Seq(
+val sentenceData = sqlContext.createDataFrame(Seq(
   (0, "Hi I heard about Spark"),
   (0, "I wish Java could use case classes"),
   (1, "Logistic regression models are neat")
 )).toDF("label", "sentence")
 val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
-val wordsDataFrame = tokenizer.transform(sentenceDataFrame)
-val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features").setNumFeatures(20)
-val featurized = hashingTF.transform(wordsDataFrame)
-featurized.select("features", "label").take(3).foreach(println)
+val wordsData = tokenizer.transform(sentenceData)
+val hashingTF = new HashingTF().setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(20)
+val featurizedData = hashingTF.transform(wordsData)
+val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features")
+val idfModel = idf.fit(featurizedData)
+val rescaledData = idfModel.transform(featurizedData)
+rescaledData.select("features", "label").take(3).foreach(println)
 {% endhighlight %}
 </div>
 
@@ -51,6 +59,7 @@ import com.google.common.collect.Lists;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.ml.feature.HashingTF;
+import org.apache.spark.ml.feature.IDF;
+import org.apache.spark.ml.feature.IDFModel;
 import org.apache.spark.ml.feature.Tokenizer;
 import org.apache.spark.mllib.linalg.Vector;
 import org.apache.spark.sql.DataFrame;
@@ -70,16 +79,19 @@ StructType schema = new StructType(new StructField[]{
   new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
   new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
 });
-DataFrame sentenceDataFrame = sqlContext.createDataFrame(jrdd, schema);
+DataFrame sentenceData = sqlContext.createDataFrame(jrdd, schema);
 Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
-DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame);
+DataFrame wordsData = tokenizer.transform(sentenceData);
 int numFeatures = 20;
 HashingTF hashingTF = new HashingTF()
   .setInputCol("words")
-  .setOutputCol("features")
+  .setOutputCol("rawFeatures")
   .setNumFeatures(numFeatures);
-DataFrame featurized = hashingTF.transform(wordsDataFrame);
-for (Row r : featurized.select("features", "label").take(3)) {
+DataFrame featurizedData = hashingTF.transform(wordsData);
+IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
+IDFModel idfModel = idf.fit(featurizedData);
+DataFrame rescaledData = idfModel.transform(featurizedData);
+for (Row r : rescaledData.select("features", "label").take(3)) {
   Vector features = r.getAs(0);
   Double label = r.getDouble(1);
   System.out.println(features);
@@ -89,19 +101,22 @@ for (Row r : featurized.select("features", "label").take(3)) {
 
 <div data-lang="python" markdown="1">
 {% highlight python %}
-from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.ml.feature import HashingTF, IDF, Tokenizer
 
-sentenceDataFrame = sqlContext.createDataFrame([
+sentenceData = sqlContext.createDataFrame([
   (0, "Hi I heard about Spark"),
   (0, "I wish Java could use case classes"),
   (1, "Logistic regression models are neat")
 ], ["label", "sentence"])
 tokenizer = Tokenizer(inputCol="sentence", outputCol="words")
-wordsDataFrame = tokenizer.transform(sentenceDataFrame)
-hashingTF = HashingTF(inputCol="words", outputCol="features", numFeatures=20)
-featurized = hashingTF.transform(wordsDataFrame)
-for features_label in featurized.select("features", "label").take(3):
-  print features_label
+wordsData = tokenizer.transform(sentenceData)
+hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=20)
+featurizedData = hashingTF.transform(wordsData)
+idf = IDF(inputCol="rawFeatures", outputCol="features")
+idfModel = idf.fit(featurizedData)
+rescaledData = idfModel.transform(featurizedData)
+for features_label in rescaledData.select("features", "label").take(3):
+  print(features_label)
 {% endhighlight %}
 </div>
 </div>
@@ -267,11 +282,12 @@ sentenceDataFrame = sqlContext.createDataFrame([
 tokenizer = Tokenizer(inputCol="sentence", outputCol="words")
 wordsDataFrame = tokenizer.transform(sentenceDataFrame)
 for words_label in wordsDataFrame.select("words", "label").take(3):
-  print words_label
+  print(words_label)
 {% endhighlight %}
 </div>
 </div>
 
+
 ## Binarizer
 
 Binarization is the process of thresholding numerical features to binary features. Since some probabilistic estimators assume that the input data follows a [Bernoulli distribution](http://en.wikipedia.org/wiki/Bernoulli_distribution), a binarizer is useful for pre-processing continuous numerical features.
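 
 Concretely, for a chosen threshold $t$ (and assuming the strict-greater-than behavior of `Binarizer`), each feature value $x$ is mapped as
 
 `\[
 x' = \begin{cases} 1.0 & \text{if } x > t \\ 0.0 & \text{otherwise.} \end{cases}
 \]`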
@@ -352,7 +368,7 @@ binarizer = Binarizer(threshold=0.5, inputCol="feature", outputCol="binarized_fe
 binarizedDataFrame = binarizer.transform(continuousDataFrame)
 binarizedFeatures = binarizedDataFrame.select("binarized_feature")
 for binarized_feature, in binarizedFeatures.collect():
-  print binarized_feature
+  print(binarized_feature)
 {% endhighlight %}
 </div>
 </div>
@@ -618,5 +634,161 @@ indexedData = indexerModel.transform(data)
 </div>
 </div>
 
+
+## Normalizer
+
+`Normalizer` is a `Transformer` which transforms a dataset of `Vector` rows, normalizing each `Vector` to have unit norm.  It takes parameter `p`, which specifies the [p-norm](http://en.wikipedia.org/wiki/Norm_%28mathematics%29#p-norm) used for normalization.  ($p = 2$ by default.)  This normalization can help standardize your input data and improve the behavior of learning algorithms.
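+
+For reference, the $p$-norm of a vector $x$ is
+
+`\[
+\|x\|_p = \left( \sum_i |x_i|^p \right)^{1/p}
+\]`
+
+and `Normalizer` divides each `Vector` by its own $p$-norm (for $p = \infty$, by its maximum absolute entry).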
+
+The following example demonstrates how to load a dataset in libsvm format and then normalize each row to have unit $L^1$ norm and unit $L^\infty$ norm.
+
+<div class="codetabs">
+<div data-lang="scala">
+{% highlight scala %}
+import org.apache.spark.ml.feature.Normalizer
+import org.apache.spark.mllib.util.MLUtils
+
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+val dataFrame = sqlContext.createDataFrame(data)
+
+// Normalize each Vector using $L^1$ norm.
+val normalizer = new Normalizer()
+  .setInputCol("features")
+  .setOutputCol("normFeatures")
+  .setP(1.0)
+val l1NormData = normalizer.transform(dataFrame)
+
+// Normalize each Vector using $L^\infty$ norm.
+val lInfNormData = normalizer.transform(dataFrame, normalizer.p -> Double.PositiveInfinity)
+{% endhighlight %}
+</div>
+
+<div data-lang="java">
+{% highlight java %}
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.ml.feature.Normalizer;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.sql.DataFrame;
+
+JavaRDD<LabeledPoint> data =
+  MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD();
+DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class);
+
+// Normalize each Vector using $L^1$ norm.
+Normalizer normalizer = new Normalizer()
+  .setInputCol("features")
+  .setOutputCol("normFeatures")
+  .setP(1.0);
+DataFrame l1NormData = normalizer.transform(dataFrame);
+
+// Normalize each Vector using $L^\infty$ norm.
+DataFrame lInfNormData =
+  normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY));
+{% endhighlight %}
+</div>
+
+<div data-lang="python">
+{% highlight python %}
+from pyspark.mllib.util import MLUtils
+from pyspark.ml.feature import Normalizer
+
+data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+dataFrame = sqlContext.createDataFrame(data)
+
+# Normalize each Vector using $L^1$ norm.
+normalizer = Normalizer(inputCol="features", outputCol="normFeatures", p=1.0)
+l1NormData = normalizer.transform(dataFrame)
+
+# Normalize each Vector using $L^\infty$ norm.
+lInfNormData = normalizer.transform(dataFrame, {normalizer.p: float("inf")})
+{% endhighlight %}
+</div>
+</div>
+
+
+## StandardScaler
+
+`StandardScaler` transforms a dataset of `Vector` rows, normalizing each feature to have unit standard deviation and/or zero mean.  It takes parameters:
+
+* `withStd`: True by default. Scales the data to unit standard deviation.
+* `withMean`: False by default. Centers the data with the mean before scaling. It will build a dense output, so this does not work on sparse input and will raise an exception.
+
+`StandardScaler` is an `Estimator` which can be `fit` on a dataset to produce a `StandardScalerModel`; this amounts to computing summary statistics.  The model can then transform a `Vector` column in a dataset to have unit standard deviation and/or zero mean features.
+
+Note that if the standard deviation of a feature is zero, the model will return the default `0.0` value in the `Vector` for that feature.
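+
+As a sketch of the transformation, using the feature-wise mean $\mu_j$ and standard deviation $\sigma_j$ computed during `fit`, each feature value $x_j$ is mapped (with both `withMean` and `withStd` enabled) to
+
+`\[
+x_j' = \frac{x_j - \mu_j}{\sigma_j}
+\]`
+
+with the centering or scaling term dropped when `withMean` or `withStd`, respectively, is disabled.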
+
+More details can be found in the API docs for
+[StandardScaler](api/scala/index.html#org.apache.spark.ml.feature.StandardScaler) and
+[StandardScalerModel](api/scala/index.html#org.apache.spark.ml.feature.StandardScalerModel).
+
+The following example demonstrates how to load a dataset in libsvm format and then normalize each feature to have unit standard deviation.
+
+<div class="codetabs">
+<div data-lang="scala">
+{% highlight scala %}
+import org.apache.spark.ml.feature.StandardScaler
+import org.apache.spark.mllib.util.MLUtils
+
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+val dataFrame = sqlContext.createDataFrame(data)
+val scaler = new StandardScaler()
+  .setInputCol("features")
+  .setOutputCol("scaledFeatures")
+  .setWithStd(true)
+  .setWithMean(false)
+
+// Compute summary statistics by fitting the StandardScaler
+val scalerModel = scaler.fit(dataFrame)
+
+// Normalize each feature to have unit standard deviation.
+val scaledData = scalerModel.transform(dataFrame)
+{% endhighlight %}
+</div>
+
+<div data-lang="java">
+{% highlight java %}
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.ml.feature.StandardScaler;
+import org.apache.spark.ml.feature.StandardScalerModel;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.sql.DataFrame;
+
+JavaRDD<LabeledPoint> data =
+  MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD();
+DataFrame dataFrame = jsql.createDataFrame(data, LabeledPoint.class);
+StandardScaler scaler = new StandardScaler()
+  .setInputCol("features")
+  .setOutputCol("scaledFeatures")
+  .setWithStd(true)
+  .setWithMean(false);
+
+// Compute summary statistics by fitting the StandardScaler
+StandardScalerModel scalerModel = scaler.fit(dataFrame);
+
+// Normalize each feature to have unit standard deviation.
+DataFrame scaledData = scalerModel.transform(dataFrame);
+{% endhighlight %}
+</div>
+
+<div data-lang="python">
+{% highlight python %}
+from pyspark.mllib.util import MLUtils
+from pyspark.ml.feature import StandardScaler
+
+data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+dataFrame = sqlContext.createDataFrame(data)
+scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures",
+                        withStd=True, withMean=False)
+
+# Compute summary statistics by fitting the StandardScaler
+scalerModel = scaler.fit(dataFrame)
+
+# Normalize each feature to have unit standard deviation.
+scaledData = scalerModel.transform(dataFrame)
+{% endhighlight %}
+</div>
+</div>
+
+
 # Feature Selectors
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2728c3df/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 23463ab..da22180 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -63,17 +63,22 @@ public class JavaHashingTFSuite {
       new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
       new StructField("sentence", DataTypes.StringType, false, 
Metadata.empty())
     });
-    DataFrame sentenceDataFrame = jsql.createDataFrame(jrdd, schema);
 
-    Tokenizer tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words");
-    DataFrame wordsDataFrame = tokenizer.transform(sentenceDataFrame);
+    DataFrame sentenceData = jsql.createDataFrame(jrdd, schema);
+    Tokenizer tokenizer = new Tokenizer()
+      .setInputCol("sentence")
+      .setOutputCol("words");
+    DataFrame wordsData = tokenizer.transform(sentenceData);
     int numFeatures = 20;
     HashingTF hashingTF = new HashingTF()
       .setInputCol("words")
-      .setOutputCol("features")
+      .setOutputCol("rawFeatures")
       .setNumFeatures(numFeatures);
-    DataFrame featurized = hashingTF.transform(wordsDataFrame);
-    for (Row r : featurized.select("features", "words", "label").take(3)) {
+    DataFrame featurizedData = hashingTF.transform(wordsData);
+    IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
+    IDFModel idfModel = idf.fit(featurizedData);
+    DataFrame rescaledData = idfModel.transform(featurizedData);
+    for (Row r : rescaledData.select("features", "label").take(3)) {
       Vector features = r.getAs(0);
       Assert.assertEquals(features.size(), numFeatures);
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/2728c3df/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
new file mode 100644
index 0000000..d82f3b7
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaNormalizerSuite {
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaNormalizerSuite");
+    jsql = new SQLContext(jsc);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void normalizer() {
+    // The tests are to check Java compatibility.
+    List<VectorIndexerSuite.FeatureData> points = Lists.newArrayList(
+      new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)),
+      new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
+      new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
+    );
+    DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
+      VectorIndexerSuite.FeatureData.class);
+    Normalizer normalizer = new Normalizer()
+      .setInputCol("features")
+      .setOutputCol("normFeatures");
+
+    // Normalize each Vector using $L^2$ norm.
+    DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2));
+    l2NormData.count();
+
+    // Normalize each Vector using $L^\infty$ norm.
+    DataFrame lInfNormData =
+      normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY));
+    lInfNormData.count();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2728c3df/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
new file mode 100644
index 0000000..74eb273
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaStandardScalerSuite {
+  private transient JavaSparkContext jsc;
+  private transient SQLContext jsql;
+
+  @Before
+  public void setUp() {
+    jsc = new JavaSparkContext("local", "JavaStandardScalerSuite");
+    jsql = new SQLContext(jsc);
+  }
+
+  @After
+  public void tearDown() {
+    jsc.stop();
+    jsc = null;
+  }
+
+  @Test
+  public void standardScaler() {
+    // The tests are to check Java compatibility.
+    List<VectorIndexerSuite.FeatureData> points = Lists.newArrayList(
+      new VectorIndexerSuite.FeatureData(Vectors.dense(0.0, -2.0)),
+      new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
+      new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
+    );
+    DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
+      VectorIndexerSuite.FeatureData.class);
+    StandardScaler scaler = new StandardScaler()
+      .setInputCol("features")
+      .setOutputCol("scaledFeatures")
+      .setWithStd(true)
+      .setWithMean(false);
+
+    // Compute summary statistics by fitting the StandardScaler
+    StandardScalerModel scalerModel = scaler.fit(dataFrame);
+
+    // Normalize each feature to have unit standard deviation.
+    DataFrame scaledData = scalerModel.transform(dataFrame);
+    scaledData.count();
+  }
+}

