[SPARK-15037][SQL][MLLIB] Use SparkSession instead of SQLContext in Scala/Java TestSuites
## What changes were proposed in this pull request? Use SparkSession instead of SQLContext in the Scala/Java TestSuites. As this PR is already very big, the Python TestSuites are being handled in a different PR. ## How was this patch tested? Existing tests Author: Sandeep Singh <[email protected]> Closes #12907 from techaddict/SPARK-15037. (cherry picked from commit ed0b4070fb50054b1ecf66ff6c32458a4967dfd3) Signed-off-by: Andrew Or <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5bf74b44 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5bf74b44 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5bf74b44 Branch: refs/heads/branch-2.0 Commit: 5bf74b44d9efcb8b0f0c3e7d129bc5ba31419551 Parents: 19a9c23 Author: Sandeep Singh <[email protected]> Authored: Tue May 10 11:17:47 2016 -0700 Committer: Andrew Or <[email protected]> Committed: Tue May 10 11:17:58 2016 -0700 ---------------------------------------------------------------------- .../org/apache/spark/ml/JavaPipelineSuite.java | 27 +- .../spark/ml/attribute/JavaAttributeSuite.java | 2 +- .../JavaDecisionTreeClassifierSuite.java | 23 +- .../classification/JavaGBTClassifierSuite.java | 18 +- .../JavaLogisticRegressionSuite.java | 49 ++-- ...JavaMultilayerPerceptronClassifierSuite.java | 36 +-- .../ml/classification/JavaNaiveBayesSuite.java | 18 +- .../ml/classification/JavaOneVsRestSuite.java | 90 +++--- .../JavaRandomForestClassifierSuite.java | 26 +- .../spark/ml/clustering/JavaKMeansSuite.java | 26 +- .../spark/ml/feature/JavaBucketizerSuite.java | 20 +- .../apache/spark/ml/feature/JavaDCTSuite.java | 21 +- .../spark/ml/feature/JavaHashingTFSuite.java | 18 +- .../spark/ml/feature/JavaNormalizerSuite.java | 19 +- .../apache/spark/ml/feature/JavaPCASuite.java | 21 +- .../feature/JavaPolynomialExpansionSuite.java | 19 +- .../ml/feature/JavaStandardScalerSuite.java | 17 +- .../ml/feature/JavaStopWordsRemoverSuite.java | 22 +- 
.../ml/feature/JavaStringIndexerSuite.java | 26 +- .../spark/ml/feature/JavaTokenizerSuite.java | 21 +- .../ml/feature/JavaVectorAssemblerSuite.java | 31 ++- .../ml/feature/JavaVectorIndexerSuite.java | 18 +- .../spark/ml/feature/JavaVectorSlicerSuite.java | 18 +- .../spark/ml/feature/JavaWord2VecSuite.java | 22 +- .../apache/spark/ml/param/JavaParamsSuite.java | 14 +- .../apache/spark/ml/param/JavaTestParams.java | 38 ++- .../JavaDecisionTreeRegressorSuite.java | 18 +- .../ml/regression/JavaGBTRegressorSuite.java | 18 +- .../regression/JavaLinearRegressionSuite.java | 25 +- .../JavaRandomForestRegressorSuite.java | 28 +- .../source/libsvm/JavaLibSVMRelationSuite.java | 18 +- .../ml/tuning/JavaCrossValidatorSuite.java | 18 +- .../spark/ml/util/IdentifiableSuite.scala | 1 + .../ml/util/JavaDefaultReadWriteSuite.java | 21 +- .../JavaLogisticRegressionSuite.java | 35 ++- .../classification/JavaNaiveBayesSuite.java | 25 +- .../mllib/classification/JavaSVMSuite.java | 32 ++- .../clustering/JavaBisectingKMeansSuite.java | 27 +- .../clustering/JavaGaussianMixtureSuite.java | 20 +- .../spark/mllib/clustering/JavaKMeansSuite.java | 23 +- .../spark/mllib/clustering/JavaLDASuite.java | 37 +-- .../clustering/JavaStreamingKMeansSuite.java | 3 +- .../evaluation/JavaRankingMetricsSuite.java | 21 +- .../spark/mllib/feature/JavaTfIdfSuite.java | 22 +- .../spark/mllib/feature/JavaWord2VecSuite.java | 19 +- .../mllib/fpm/JavaAssociationRulesSuite.java | 23 +- .../spark/mllib/fpm/JavaFPGrowthSuite.java | 29 +- .../spark/mllib/fpm/JavaPrefixSpanSuite.java | 26 +- .../spark/mllib/linalg/JavaMatricesSuite.java | 278 ++++++++++--------- .../spark/mllib/linalg/JavaVectorsSuite.java | 7 +- .../spark/mllib/random/JavaRandomRDDsSuite.java | 136 ++++----- .../mllib/recommendation/JavaALSSuite.java | 64 +++-- .../regression/JavaIsotonicRegressionSuite.java | 22 +- .../spark/mllib/regression/JavaLassoSuite.java | 32 ++- .../regression/JavaLinearRegressionSuite.java | 42 +-- 
.../regression/JavaRidgeRegressionSuite.java | 22 +- .../spark/mllib/stat/JavaStatisticsSuite.java | 32 ++- .../spark/mllib/tree/JavaDecisionTreeSuite.java | 24 +- .../org/apache/spark/ml/PipelineSuite.scala | 2 +- .../ml/classification/ClassifierSuite.scala | 4 +- .../DecisionTreeClassifierSuite.scala | 4 +- .../ml/classification/GBTClassifierSuite.scala | 6 +- .../LogisticRegressionSuite.scala | 12 +- .../MultilayerPerceptronClassifierSuite.scala | 8 +- .../ml/classification/NaiveBayesSuite.scala | 12 +- .../ml/classification/OneVsRestSuite.scala | 4 +- .../RandomForestClassifierSuite.scala | 4 +- .../ml/clustering/BisectingKMeansSuite.scala | 2 +- .../ml/clustering/GaussianMixtureSuite.scala | 2 +- .../spark/ml/clustering/KMeansSuite.scala | 10 +- .../apache/spark/ml/clustering/LDASuite.scala | 16 +- .../BinaryClassificationEvaluatorSuite.scala | 8 +- ...MulticlassClassificationEvaluatorSuite.scala | 2 +- .../evaluation/RegressionEvaluatorSuite.scala | 4 +- .../spark/ml/feature/BinarizerSuite.scala | 8 +- .../spark/ml/feature/BucketizerSuite.scala | 8 +- .../spark/ml/feature/ChiSqSelectorSuite.scala | 9 +- .../spark/ml/feature/CountVectorizerSuite.scala | 14 +- .../org/apache/spark/ml/feature/DCTSuite.scala | 2 +- .../spark/ml/feature/HashingTFSuite.scala | 4 +- .../org/apache/spark/ml/feature/IDFSuite.scala | 4 +- .../spark/ml/feature/InteractionSuite.scala | 12 +- .../spark/ml/feature/MaxAbsScalerSuite.scala | 2 +- .../spark/ml/feature/MinMaxScalerSuite.scala | 4 +- .../apache/spark/ml/feature/NGramSuite.scala | 8 +- .../spark/ml/feature/NormalizerSuite.scala | 2 +- .../spark/ml/feature/OneHotEncoderSuite.scala | 6 +- .../org/apache/spark/ml/feature/PCASuite.scala | 2 +- .../ml/feature/PolynomialExpansionSuite.scala | 6 +- .../apache/spark/ml/feature/RFormulaSuite.scala | 44 +-- .../spark/ml/feature/SQLTransformerSuite.scala | 6 +- .../spark/ml/feature/StandardScalerSuite.scala | 8 +- .../ml/feature/StopWordsRemoverSuite.scala | 16 +- 
.../spark/ml/feature/StringIndexerSuite.scala | 18 +- .../spark/ml/feature/TokenizerSuite.scala | 8 +- .../spark/ml/feature/VectorAssemblerSuite.scala | 6 +- .../spark/ml/feature/VectorIndexerSuite.scala | 12 +- .../spark/ml/feature/VectorSlicerSuite.scala | 2 +- .../apache/spark/ml/feature/Word2VecSuite.scala | 16 +- .../spark/ml/recommendation/ALSSuite.scala | 21 +- .../regression/AFTSurvivalRegressionSuite.scala | 8 +- .../regression/DecisionTreeRegressorSuite.scala | 2 +- .../spark/ml/regression/GBTRegressorSuite.scala | 6 +- .../GeneralizedLinearRegressionSuite.scala | 32 +-- .../ml/regression/IsotonicRegressionSuite.scala | 8 +- .../ml/regression/LinearRegressionSuite.scala | 18 +- .../regression/RandomForestRegressorSuite.scala | 2 +- .../ml/source/libsvm/LibSVMRelationSuite.scala | 14 +- .../tree/impl/GradientBoostedTreesSuite.scala | 4 +- .../apache/spark/ml/tree/impl/TreeTests.scala | 10 +- .../spark/ml/tuning/CrossValidatorSuite.scala | 4 +- .../ml/tuning/TrainValidationSplitSuite.scala | 4 +- .../apache/spark/ml/util/MLTestingUtils.scala | 28 +- .../mllib/util/MLlibTestSparkContext.scala | 24 +- .../apache/spark/sql/JavaApplySchemaSuite.java | 42 +-- .../apache/spark/sql/JavaDataFrameSuite.java | 70 +++-- .../org/apache/spark/sql/JavaDatasetSuite.java | 89 +++--- .../test/org/apache/spark/sql/JavaUDFSuite.java | 27 +- .../sources/JavaDatasetAggregatorSuiteBase.java | 20 +- .../spark/sql/sources/JavaSaveLoadSuite.java | 33 ++- .../org/apache/spark/sql/CachedTableSuite.scala | 216 +++++++------- .../spark/sql/ColumnExpressionSuite.scala | 12 +- .../spark/sql/DataFrameAggregateSuite.scala | 12 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 2 +- .../apache/spark/sql/DataFramePivotSuite.scala | 6 +- .../apache/spark/sql/DataFrameStatSuite.scala | 8 +- .../org/apache/spark/sql/DataFrameSuite.scala | 82 +++--- .../spark/sql/DataFrameTimeWindowingSuite.scala | 8 +- .../spark/sql/DataFrameTungstenSuite.scala | 4 +- 
.../org/apache/spark/sql/DatasetBenchmark.scala | 38 +-- .../apache/spark/sql/DatasetCacheSuite.scala | 13 +- .../org/apache/spark/sql/DatasetSuite.scala | 26 +- .../apache/spark/sql/ExtraStrategiesSuite.scala | 4 +- .../scala/org/apache/spark/sql/JoinSuite.scala | 22 +- .../org/apache/spark/sql/ListTablesSuite.scala | 20 +- .../apache/spark/sql/LocalSparkSession.scala | 68 +++++ .../scala/org/apache/spark/sql/QueryTest.scala | 8 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 72 ++--- .../apache/spark/sql/SerializationSuite.scala | 4 +- .../scala/org/apache/spark/sql/StreamTest.scala | 2 +- .../apache/spark/sql/StringFunctionsSuite.scala | 2 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 52 ++-- .../apache/spark/sql/UserDefinedTypeSuite.scala | 10 +- .../execution/ExchangeCoordinatorSuite.scala | 45 +-- .../spark/sql/execution/ExchangeSuite.scala | 4 +- .../spark/sql/execution/PlannerSuite.scala | 36 +-- .../spark/sql/execution/SQLExecutionSuite.scala | 19 +- .../spark/sql/execution/SparkPlanTest.scala | 27 +- .../sql/execution/WholeStageCodegenSuite.scala | 20 +- .../columnar/InMemoryColumnarQuerySuite.scala | 39 +-- .../columnar/PartitionBatchPruningSuite.scala | 19 +- .../spark/sql/execution/command/DDLSuite.scala | 39 +-- .../datasources/FileSourceStrategySuite.scala | 6 +- .../datasources/HadoopFsRelationSuite.scala | 4 +- .../execution/datasources/csv/CSVSuite.scala | 80 +++--- .../json/JsonParsingOptionsSuite.scala | 48 ++-- .../execution/datasources/json/JsonSuite.scala | 124 ++++----- .../datasources/json/TestJsonData.scala | 46 +-- .../parquet/ParquetAvroCompatibilitySuite.scala | 14 +- .../parquet/ParquetCompatibilityTest.scala | 2 +- .../parquet/ParquetFilterSuite.scala | 34 +-- .../datasources/parquet/ParquetIOSuite.scala | 49 ++-- .../parquet/ParquetInteroperabilitySuite.scala | 2 +- .../ParquetPartitionDiscoverySuite.scala | 34 +-- .../datasources/parquet/ParquetQuerySuite.scala | 102 +++---- .../parquet/ParquetReadBenchmark.scala | 76 ++--- 
.../parquet/ParquetSchemaSuite.scala | 14 +- .../datasources/parquet/ParquetTest.scala | 10 +- .../ParquetThriftCompatibilitySuite.scala | 4 +- .../datasources/parquet/TPCDSBenchmark.scala | 21 +- .../execution/datasources/text/TextSuite.scala | 20 +- .../sql/execution/debug/DebuggingSuite.scala | 2 +- .../execution/joins/BroadcastJoinSuite.scala | 18 +- .../sql/execution/joins/InnerJoinSuite.scala | 10 +- .../sql/execution/joins/OuterJoinSuite.scala | 8 +- .../sql/execution/metric/SQLMetricsSuite.scala | 30 +- .../streaming/FileStreamSinkLogSuite.scala | 4 +- .../streaming/HDFSMetadataLogSuite.scala | 33 +-- .../streaming/state/StateStoreRDDSuite.scala | 58 ++-- .../sql/execution/ui/SQLListenerSuite.scala | 36 +-- .../spark/sql/internal/CatalogSuite.scala | 69 +++-- .../spark/sql/internal/SQLConfSuite.scala | 74 ++--- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 34 +-- .../apache/spark/sql/jdbc/JDBCWriteSuite.scala | 46 +-- .../spark/sql/sources/DataSourceTest.scala | 2 +- .../sql/sources/PartitionedWriteSuite.scala | 10 +- .../streaming/ContinuousQueryManagerSuite.scala | 50 ++-- .../streaming/DataFrameReaderWriterSuite.scala | 82 +++--- .../sql/streaming/FileStreamSinkSuite.scala | 18 +- .../sql/streaming/FileStreamSourceSuite.scala | 6 +- .../spark/sql/streaming/FileStressSuite.scala | 4 +- .../spark/sql/streaming/MemorySinkSuite.scala | 4 +- .../spark/sql/streaming/StreamSuite.scala | 10 +- .../org/apache/spark/sql/test/SQLTestData.scala | 56 ++-- .../apache/spark/sql/test/SQLTestUtils.scala | 38 +-- .../spark/sql/test/SharedSQLContext.scala | 29 +- .../apache/spark/sql/test/TestSQLContext.scala | 48 ++-- .../sql/util/ContinuousQueryListenerSuite.scala | 20 +- .../spark/sql/util/DataFrameCallbackSuite.scala | 18 +- .../spark/sql/hive/test/TestHiveSingleton.scala | 4 +- .../sql/catalyst/ExpressionToSQLSuite.scala | 4 +- .../sql/catalyst/LogicalPlanToSQLSuite.scala | 12 +- .../spark/sql/catalyst/SQLBuilderTest.scala | 2 +- 
.../spark/sql/hive/ErrorPositionSuite.scala | 8 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 16 +- .../sql/hive/InsertIntoHiveTableSuite.scala | 6 +- .../sql/hive/MetastoreDataSourcesSuite.scala | 18 +- .../spark/sql/hive/MultiDatabaseSuite.scala | 88 +++--- .../hive/ParquetHiveCompatibilitySuite.scala | 12 +- .../hive/execution/AggregationQuerySuite.scala | 110 ++++---- .../spark/sql/hive/execution/HiveDDLSuite.scala | 18 +- .../spark/sql/hive/execution/HiveUDFSuite.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 32 +-- .../spark/sql/hive/execution/SQLViewSuite.scala | 14 +- .../hive/execution/SQLWindowFunctionSuite.scala | 2 +- .../sql/hive/orc/OrcHadoopFsRelationSuite.scala | 8 +- .../spark/sql/hive/orc/OrcQuerySuite.scala | 38 +-- .../org/apache/spark/sql/hive/orc/OrcTest.scala | 4 +- .../apache/spark/sql/hive/parquetSuites.scala | 2 +- .../spark/sql/sources/BucketedWriteSuite.scala | 2 +- .../CommitFailureTestRelationSuite.scala | 6 +- .../sql/sources/HadoopFsRelationTest.scala | 66 ++--- .../sources/ParquetHadoopFsRelationSuite.scala | 20 +- .../SimpleTextHadoopFsRelationSuite.scala | 4 +- 224 files changed, 2934 insertions(+), 2611 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 60a4a1d..e0c4363 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -17,18 +17,18 @@ package org.apache.spark.ml; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; import org.junit.After; import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaRDD; 
import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; /** @@ -36,23 +36,26 @@ import static org.apache.spark.mllib.classification.LogisticRegressionSuite.gene */ public class JavaPipelineSuite { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; private transient Dataset<Row> dataset; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaPipelineSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaPipelineSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); JavaRDD<LabeledPoint> points = jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); - dataset = jsql.createDataFrame(points, LabeledPoint.class); + dataset = spark.createDataFrame(points, LabeledPoint.class); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -63,10 +66,10 @@ public class JavaPipelineSuite { LogisticRegression lr = new LogisticRegression() .setFeaturesCol("scaledFeatures"); Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {scaler, lr}); + .setStages(new PipelineStage[]{scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction"); 
predictions.collectAsList(); } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java index b74bbed..15cde0d 100644 --- a/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/attribute/JavaAttributeSuite.java @@ -17,8 +17,8 @@ package org.apache.spark.ml.attribute; -import org.junit.Test; import org.junit.Assert; +import org.junit.Test; public class JavaAttributeSuite { http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java index 1f23682..8b89991 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -21,8 +21,6 @@ import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -32,21 +30,28 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.tree.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; - +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import 
org.apache.spark.sql.SparkSession; public class JavaDecisionTreeClassifierSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaDecisionTreeClassifierSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -55,7 +60,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD<LabeledPoint> data = sc.parallelize( + JavaRDD<LabeledPoint> data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); @@ -70,7 +75,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable { .setCacheNodeIds(false) .setCheckpointInterval(10) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String impurity: DecisionTreeClassifier.supportedImpurities()) { + for (String impurity : DecisionTreeClassifier.supportedImpurities()) { dt.setImpurity(impurity); } DecisionTreeClassificationModel model = dt.fit(dataFrame); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java index 7484105..682371e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java +++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java @@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; public class JavaGBTClassifierSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaGBTClassifierSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaGBTClassifierSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -55,7 +61,7 @@ public class JavaGBTClassifierSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD<LabeledPoint> data = sc.parallelize( + JavaRDD<LabeledPoint> data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); @@ -74,7 +80,7 @@ public class JavaGBTClassifierSuite implements Serializable { .setMaxIter(3) .setStepSize(0.1) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String lossType: GBTClassifier.supportedLossTypes()) { + for (String lossType : GBTClassifier.supportedLossTypes()) { rf.setLossType(lossType); } GBTClassificationModel model = rf.fit(dataFrame); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java ---------------------------------------------------------------------- diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index e160a5a..e3ff683 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -27,18 +27,17 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; - +import org.apache.spark.sql.SparkSession; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; public class JavaLogisticRegressionSuite implements Serializable { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; private transient Dataset<Row> dataset; private transient JavaRDD<LabeledPoint> datasetRDD; @@ -46,18 +45,22 @@ public class JavaLogisticRegressionSuite implements Serializable { @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaLogisticRegressionSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); + List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42); datasetRDD = jsc.parallelize(points, 2); - dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); + dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class); dataset.registerTempTable("dataset"); } @After public void tearDown() { - jsc.stop(); - 
jsc = null; + spark.stop(); + spark = null; } @Test @@ -66,7 +69,7 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(lr.getLabelCol(), "label"); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); // Check defaults Assert.assertEquals(0.5, model.getThreshold(), eps); @@ -95,23 +98,23 @@ public class JavaLogisticRegressionSuite implements Serializable { // Modify model params, and check that the params worked. model.setThreshold(1.0); model.transform(dataset).registerTempTable("predAllZero"); - Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); - for (Row r: predAllZero.collectAsList()) { + Dataset<Row> predAllZero = spark.sql("SELECT prediction, myProbability FROM predAllZero"); + for (Row r : predAllZero.collectAsList()) { Assert.assertEquals(0.0, r.getDouble(0), eps); } // Call transform with params, and check that the params worked. model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); - Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); + Dataset<Row> predNotAllZero = spark.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; - for (Row r: predNotAllZero.collectAsList()) { + for (Row r : predNotAllZero.collectAsList()) { if (r.getDouble(0) != 0.0) foundNonZero = true; } Assert.assertTrue(foundNonZero); // Call fit() with new params, and check as many params as we can. 
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), - lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); + lr.threshold().w(0.4), lr.probabilityCol().w("theProb")); LogisticRegression parent2 = (LogisticRegression) model2.parent(); Assert.assertEquals(5, parent2.getMaxIter()); Assert.assertEquals(0.1, parent2.getRegParam(), eps); @@ -128,10 +131,10 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(2, model.numClasses()); model.transform(dataset).registerTempTable("transformed"); - Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); - for (Row row: trans1.collectAsList()) { - Vector raw = (Vector)row.get(0); - Vector prob = (Vector)row.get(1); + Dataset<Row> trans1 = spark.sql("SELECT rawPrediction, probability FROM transformed"); + for (Row row : trans1.collectAsList()) { + Vector raw = (Vector) row.get(0); + Vector prob = (Vector) row.get(1); Assert.assertEquals(raw.size(), 2); Assert.assertEquals(prob.size(), 2); double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); @@ -139,11 +142,11 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps); } - Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); - for (Row row: trans2.collectAsList()) { + Dataset<Row> trans2 = spark.sql("SELECT prediction, probability FROM transformed"); + for (Row row : trans2.collectAsList()) { double pred = row.getDouble(0); - Vector prob = (Vector)row.get(1); - double probOfPred = prob.apply((int)pred); + Vector prob = (Vector) row.get(1); + double probOfPred = prob.apply((int) pred); for (int i = 0; i < prob.size(); ++i) { Assert.assertTrue(probOfPred >= prob.apply(i)); } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java 
---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java index bc955f3..b0624ce 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -26,49 +26,49 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaMultilayerPerceptronClassifierSuite implements Serializable { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - sqlContext = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaLogisticRegressionSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; - sqlContext = null; + spark.stop(); + spark = null; } @Test public void testMLPC() { - Dataset<Row> dataFrame = sqlContext.createDataFrame( - jsc.parallelize(Arrays.asList( - new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), - new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), - new LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), - new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)))), - LabeledPoint.class); + List<LabeledPoint> data = Arrays.asList( + new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + new 
LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + new LabeledPoint(0.0, Vectors.dense(1.0, 1.0)) + ); + Dataset<Row> dataFrame = spark.createDataFrame(data, LabeledPoint.class); + MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier() - .setLayers(new int[] {2, 5, 2}) + .setLayers(new int[]{2, 5, 2}) .setBlockSize(1) .setSeed(123L) .setMaxIter(100); MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); Dataset<Row> result = model.transform(dataFrame); List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList(); - for (Row r: predictionAndLabels) { + for (Row r : predictionAndLabels) { Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1)); } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 45101f2..3fc3648 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -26,13 +26,12 @@ import org.junit.Before; import org.junit.Test; import static org.junit.Assert.assertEquals; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; @@ -40,19 +39,20 @@ import org.apache.spark.sql.types.StructType; public class JavaNaiveBayesSuite 
implements Serializable { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaLogisticRegressionSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } public void validatePrediction(Dataset<Row> predictionAndLabels) { @@ -88,7 +88,7 @@ public class JavaNaiveBayesSuite implements Serializable { new StructField("features", new VectorUDT(), false, Metadata.empty()) }); - Dataset<Row> dataset = jsql.createDataFrame(data, schema); + Dataset<Row> dataset = spark.createDataFrame(data, schema); NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index 00f4476..486fbbd 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -20,7 +20,6 @@ package org.apache.spark.ml.classification; import java.io.Serializable; import java.util.List; -import org.apache.spark.sql.Row; import scala.collection.JavaConverters; import org.junit.After; @@ -30,56 +29,61 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; import 
org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; public class JavaOneVsRestSuite implements Serializable { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; - private transient Dataset<Row> dataset; - private transient JavaRDD<LabeledPoint> datasetRDD; + private transient SparkSession spark; + private transient JavaSparkContext jsc; + private transient Dataset<Row> dataset; + private transient JavaRDD<LabeledPoint> datasetRDD; + + @Before + public void setUp() { + spark = SparkSession.builder() + .master("local") + .appName("JavaLOneVsRestSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); - @Before - public void setUp() { - jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite"); - jsql = new SQLContext(jsc); - int nPoints = 3; + int nPoints = 3; - // The following coefficients and xMean/xVariance are computed from iris dataset with - // lambda=0.2. - // As a result, we are drawing samples from probability distribution of an actual model. - double[] coefficients = { - -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, - -0.16624, -0.84355, -0.048509, -0.301789, 4.170682 }; + // The following coefficients and xMean/xVariance are computed from iris dataset with + // lambda=0.2. + // As a result, we are drawing samples from probability distribution of an actual model. 
+ double[] coefficients = { + -0.57997, 0.912083, -0.371077, -0.819866, 2.688191, + -0.16624, -0.84355, -0.048509, -0.301789, 4.170682}; - double[] xMean = {5.843, 3.057, 3.758, 1.199}; - double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; - List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter( - generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) - ).asJava(); - datasetRDD = jsc.parallelize(points, 2); - dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class); - } + double[] xMean = {5.843, 3.057, 3.758, 1.199}; + double[] xVariance = {0.6856, 0.1899, 3.116, 0.581}; + List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter( + generateMultinomialLogisticInput(coefficients, xMean, xVariance, true, nPoints, 42) + ).asJava(); + datasetRDD = jsc.parallelize(points, 2); + dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class); + } - @After - public void tearDown() { - jsc.stop(); - jsc = null; - } + @After + public void tearDown() { + spark.stop(); + spark = null; + } - @Test - public void oneVsRestDefaultParams() { - OneVsRest ova = new OneVsRest(); - ova.setClassifier(new LogisticRegression()); - Assert.assertEquals(ova.getLabelCol() , "label"); - Assert.assertEquals(ova.getPredictionCol() , "prediction"); - OneVsRestModel ovaModel = ova.fit(dataset); - Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction"); - predictions.collectAsList(); - Assert.assertEquals(ovaModel.getLabelCol(), "label"); - Assert.assertEquals(ovaModel.getPredictionCol() , "prediction"); - } + @Test + public void oneVsRestDefaultParams() { + OneVsRest ova = new OneVsRest(); + ova.setClassifier(new LogisticRegression()); + Assert.assertEquals(ova.getLabelCol(), "label"); + Assert.assertEquals(ova.getPredictionCol(), "prediction"); + OneVsRestModel ovaModel = ova.fit(dataset); + Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction"); + predictions.collectAsList(); + 
Assert.assertEquals(ovaModel.getLabelCol(), "label"); + Assert.assertEquals(ovaModel.getPredictionCol(), "prediction"); + } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 4f40fd6..e385566 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -34,21 +34,27 @@ import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; public class JavaRandomForestClassifierSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaRandomForestClassifierSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaRandomForestClassifierSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -57,7 +63,7 @@ public class JavaRandomForestClassifierSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD<LabeledPoint> data = sc.parallelize( + JavaRDD<LabeledPoint> data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); Dataset<Row> dataFrame = TreeTests.setMetadata(data, 
categoricalFeatures, 2); @@ -75,22 +81,22 @@ public class JavaRandomForestClassifierSuite implements Serializable { .setSeed(1234) .setNumTrees(3) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String impurity: RandomForestClassifier.supportedImpurities()) { + for (String impurity : RandomForestClassifier.supportedImpurities()) { rf.setImpurity(impurity); } - for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) { + for (String featureSubsetStrategy : RandomForestClassifier.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } String[] realStrategies = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; - for (String strategy: realStrategies) { + for (String strategy : realStrategies) { rf.setFeatureSubsetStrategy(strategy); } String[] integerStrategies = {"1", "10", "100", "1000", "10000"}; - for (String strategy: integerStrategies) { + for (String strategy : integerStrategies) { rf.setFeatureSubsetStrategy(strategy); } String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; - for (String strategy: invalidStrategies) { + for (String strategy : invalidStrategies) { try { rf.setFeatureSubsetStrategy(strategy); Assert.fail("Expected exception to be thrown for invalid strategies"); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java index a3fcdb5..3ab09ac 100644 --- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java @@ -21,37 +21,37 @@ import java.io.Serializable; import java.util.Arrays; import java.util.List; +import static 
org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaKMeansSuite implements Serializable { private transient int k = 5; - private transient JavaSparkContext sc; private transient Dataset<Row> dataset; - private transient SQLContext sql; + private transient SparkSession spark; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaKMeansSuite"); - sql = new SQLContext(sc); - - dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k); + spark = SparkSession.builder() + .master("local") + .appName("JavaKMeansSuite") + .getOrCreate(); + dataset = KMeansSuite.generateKMeansData(spark, 50, 3, k); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -65,7 +65,7 @@ public class JavaKMeansSuite implements Serializable { Dataset<Row> transformed = model.transform(dataset); List<String> columns = Arrays.asList(transformed.columns()); List<String> expectedColumns = Arrays.asList("features", "prediction"); - for (String column: expectedColumns) { + for (String column : expectedColumns) { assertTrue(columns.contains(column)); } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index 77e3a48..a96b43d 100644 --- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -25,40 +25,40 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; public class JavaBucketizerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaBucketizerSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaBucketizerSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test public void bucketizerTest() { double[] splits = {-0.5, 0.0, 0.5}; - StructType schema = new StructType(new StructField[] { + StructType schema = new StructType(new StructField[]{ new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) }); - Dataset<Row> dataset = jsql.createDataFrame( + Dataset<Row> dataset = spark.createDataFrame( Arrays.asList( RowFactory.create(-0.5), RowFactory.create(-0.3), http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index ed1ad4c..06482d8 100644 --- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -21,43 +21,44 @@ import java.util.Arrays; import java.util.List; import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; + import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; public class JavaDCTSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaDCTSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaDCTSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test public void javaCompatibilityTest() { - double[] input = new double[] {1D, 2D, 3D, 4D}; - Dataset<Row> dataset = jsql.createDataFrame( + double[] input = new double[]{1D, 2D, 3D, 4D}; + Dataset<Row> dataset = spark.createDataFrame( Arrays.asList(RowFactory.create(Vectors.dense(input))), new StructType(new StructField[]{ new StructField("vec", (new VectorUDT()), false, Metadata.empty()) http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java ---------------------------------------------------------------------- diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 6e2cc7e..0e21d4a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -25,12 +25,11 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; @@ -38,19 +37,20 @@ import org.apache.spark.sql.types.StructType; public class JavaHashingTFSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaHashingTFSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaHashingTFSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -65,7 +65,7 @@ public class JavaHashingTFSuite { new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - Dataset<Row> sentenceData = jsql.createDataFrame(data, schema); + Dataset<Row> sentenceData = spark.createDataFrame(data, schema); Tokenizer tokenizer = new Tokenizer() .setInputCol("sentence") .setOutputCol("words"); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java ---------------------------------------------------------------------- diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java index 5bbd963..04b2897 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -23,27 +23,30 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaNormalizerSuite { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaNormalizerSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaNormalizerSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -54,7 +57,7 @@ public class JavaNormalizerSuite { new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) )); - Dataset<Row> dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); + Dataset<Row> dataFrame = spark.createDataFrame(points, VectorIndexerSuite.FeatureData.class); Normalizer normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normFeatures"); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java ---------------------------------------------------------------------- diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java index 1389d17..32f6b43 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -28,31 +28,34 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.linalg.distributed.RowMatrix; +import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.linalg.distributed.RowMatrix; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaPCASuite implements Serializable { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaPCASuite"); - sqlContext = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaPCASuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } public static class VectorPair implements Serializable { @@ -100,7 +103,7 @@ public class JavaPCASuite implements Serializable { } ); - Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, VectorPair.class); + Dataset<Row> df = spark.createDataFrame(featuresExpected, VectorPair.class); PCAModel pca = new PCA() .setInputCol("features") .setOutputCol("pca_features") 
http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java index 6a8bb64..8f72607 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -32,19 +32,22 @@ import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; public class JavaPolynomialExpansionSuite { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaPolynomialExpansionSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaPolynomialExpansionSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After @@ -72,20 +75,20 @@ public class JavaPolynomialExpansionSuite { ) ); - StructType schema = new StructType(new StructField[] { + StructType schema = new StructType(new StructField[]{ new StructField("features", new VectorUDT(), false, Metadata.empty()), new StructField("expected", new VectorUDT(), false, Metadata.empty()) }); - Dataset<Row> dataset = jsql.createDataFrame(data, schema); + Dataset<Row> dataset = spark.createDataFrame(data, schema); List<Row> pairs = polyExpansion.transform(dataset) .select("polyFeatures", "expected") 
.collectAsList(); for (Row r : pairs) { - double[] polyFeatures = ((Vector)r.get(0)).toArray(); - double[] expected = ((Vector)r.get(1)).toArray(); + double[] polyFeatures = ((Vector) r.get(0)).toArray(); + double[] expected = ((Vector) r.get(1)).toArray(); Assert.assertArrayEquals(polyFeatures, expected, 1e-1); } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java index 3f6fc33..c7397bd 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -28,22 +28,25 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaStandardScalerSuite { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaStandardScalerSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaStandardScalerSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -54,7 +57,7 @@ public class JavaStandardScalerSuite { new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) ); - Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + Dataset<Row> dataFrame = 
spark.createDataFrame(jsc.parallelize(points, 2), VectorIndexerSuite.FeatureData.class); StandardScaler scaler = new StandardScaler() .setInputCol("features") http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java index bdcbde5..2b156f3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java @@ -24,11 +24,10 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; @@ -37,19 +36,20 @@ import org.apache.spark.sql.types.StructType; public class JavaStopWordsRemoverSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaStopWordsRemoverSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaStopWordsRemoverSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -62,11 +62,11 @@ public class JavaStopWordsRemoverSuite { RowFactory.create(Arrays.asList("I", "saw", "the", "red", "baloon")), RowFactory.create(Arrays.asList("Mary", "had", "a", "little", "lamb")) ); - StructType 
schema = new StructType(new StructField[] { + StructType schema = new StructType(new StructField[]{ new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, - Metadata.empty()) + Metadata.empty()) }); - Dataset<Row> dataset = jsql.createDataFrame(data, schema); + Dataset<Row> dataset = spark.createDataFrame(data, schema); remover.transform(dataset).collect(); } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 431779c..52c0bde 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -25,40 +25,42 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SparkConf; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.types.DataTypes.*; public class JavaStringIndexerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaStringIndexerSuite"); - sqlContext = new SQLContext(jsc); + SparkConf sparkConf = new SparkConf(); + sparkConf.setMaster("local"); + sparkConf.setAppName("JavaStringIndexerSuite"); + + spark = SparkSession.builder().config(sparkConf).getOrCreate(); } @After public void tearDown() { - 
jsc.stop(); - sqlContext = null; + spark.stop(); + spark = null; } @Test public void testStringIndexer() { - StructType schema = createStructType(new StructField[] { + StructType schema = createStructType(new StructField[]{ createStructField("id", IntegerType, false), createStructField("label", StringType, false) }); List<Row> data = Arrays.asList( cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c")); - Dataset<Row> dataset = sqlContext.createDataFrame(data, schema); + Dataset<Row> dataset = spark.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() .setInputCol("label") @@ -70,7 +72,9 @@ public class JavaStringIndexerSuite { output.orderBy("id").select("id", "labelIndex").collectAsList()); } - /** An alias for RowFactory.create. */ + /** + * An alias for RowFactory.create. + */ private Row cr(Object... values) { return RowFactory.create(values); } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index 83d16cb..0bac283 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -29,22 +29,25 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaTokenizerSuite { + private transient SparkSession spark; private transient JavaSparkContext jsc; - private transient SQLContext jsql; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaTokenizerSuite"); - jsql = new SQLContext(jsc); + spark = 
SparkSession.builder() + .master("local") + .appName("JavaTokenizerSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -59,10 +62,10 @@ public class JavaTokenizerSuite { JavaRDD<TokenizerTestData> rdd = jsc.parallelize(Arrays.asList( - new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), - new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) + new TokenizerTestData("Test of tok.", new String[]{"Test", "tok."}), + new TokenizerTestData("Te,st. punct", new String[]{"Te,st.", "punct"}) )); - Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); + Dataset<Row> dataset = spark.createDataFrame(rdd, TokenizerTestData.class); List<Row> pairs = myRegExTokenizer.transform(dataset) .select("tokens", "wantedTokens") http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java index e45e198..8774cd0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -24,36 +24,39 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.SparkConf; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.types.*; +import 
org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.types.DataTypes.*; public class JavaVectorAssemblerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaVectorAssemblerSuite"); - sqlContext = new SQLContext(jsc); + SparkConf sparkConf = new SparkConf(); + sparkConf.setMaster("local"); + sparkConf.setAppName("JavaVectorAssemblerSuite"); + + spark = SparkSession.builder().config(sparkConf).getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test public void testVectorAssembler() { - StructType schema = createStructType(new StructField[] { + StructType schema = createStructType(new StructField[]{ createStructField("id", IntegerType, false), createStructField("x", DoubleType, false), createStructField("y", new VectorUDT(), false), @@ -63,14 +66,14 @@ public class JavaVectorAssemblerSuite { }); Row row = RowFactory.create( 0, 0.0, Vectors.dense(1.0, 2.0), "a", - Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); - Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema); + Vectors.sparse(2, new int[]{1}, new double[]{3.0}), 10L); + Dataset<Row> dataset = spark.createDataFrame(Arrays.asList(row), schema); VectorAssembler assembler = new VectorAssembler() - .setInputCols(new String[] {"x", "y", "z", "n"}) + .setInputCols(new String[]{"x", "y", "z", "n"}) .setOutputCol("features"); Dataset<Row> output = assembler.transform(dataset); Assert.assertEquals( - Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), + Vectors.sparse(6, new int[]{1, 2, 4, 5}, new double[]{1.0, 2.0, 3.0, 10.0}), output.select("features").first().<Vector>getAs(0)); } } 
http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index fec6cac..c386c9a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -32,21 +32,26 @@ import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; public class JavaVectorIndexerSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaVectorIndexerSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaVectorIndexerSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -57,8 +62,7 @@ public class JavaVectorIndexerSuite implements Serializable { new FeatureData(Vectors.dense(1.0, 3.0)), new FeatureData(Vectors.dense(1.0, 4.0)) ); - SQLContext sqlContext = new SQLContext(sc); - Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + Dataset<Row> data = spark.createDataFrame(jsc.parallelize(points, 2), FeatureData.class); VectorIndexer indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") 
http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index e2da111..59ad3c2 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -25,7 +25,6 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.attribute.Attribute; import org.apache.spark.ml.attribute.AttributeGroup; import org.apache.spark.ml.attribute.NumericAttribute; @@ -34,24 +33,25 @@ import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.StructType; public class JavaVectorSlicerSuite { - private transient JavaSparkContext jsc; - private transient SQLContext jsql; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaVectorSlicerSuite"); - jsql = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaVectorSlicerSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -69,7 +69,7 @@ public class JavaVectorSlicerSuite { ); Dataset<Row> dataset = - jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); + spark.createDataFrame(data, (new StructType()).add(group.toStructField())); VectorSlicer vectorSlicer = new VectorSlicer() .setInputCol("userFeatures").setOutputCol("features"); 
http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index 7517b70..392aabc 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -24,28 +24,28 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.*; public class JavaWord2VecSuite { - private transient JavaSparkContext jsc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaWord2VecSuite"); - sqlContext = new SQLContext(jsc); + spark = SparkSession.builder() + .master("local") + .appName("JavaWord2VecSuite") + .getOrCreate(); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -53,7 +53,7 @@ public class JavaWord2VecSuite { StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); - Dataset<Row> documentDF = sqlContext.createDataFrame( + Dataset<Row> documentDF = spark.createDataFrame( Arrays.asList( RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), @@ -68,8 +68,8 @@ public class JavaWord2VecSuite { Word2VecModel model 
= word2Vec.fit(documentDF); Dataset<Row> result = model.transform(documentDF); - for (Row r: result.select("result").collectAsList()) { - double[] polyFeatures = ((Vector)r.get(0)).toArray(); + for (Row r : result.select("result").collectAsList()) { + double[] polyFeatures = ((Vector) r.get(0)).toArray(); Assert.assertEquals(polyFeatures.length, 3); } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java index fa777f3..a5b5dd4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -25,23 +25,29 @@ import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.SparkSession; /** * Test Param and related classes in Java */ public class JavaParamsSuite { + private transient SparkSession spark; private transient JavaSparkContext jsc; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaParamsSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaParamsSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - jsc.stop(); - jsc = null; + spark.stop(); + spark = null; } @Test @@ -51,7 +57,7 @@ public class JavaParamsSuite { testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a"); Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0); Assert.assertEquals(testParams.getMyStringParam(), "a"); - Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[] {1.0, 2.0}, 0.0); + Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[]{1.0, 2.0}, 0.0); } @Test 
http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 06f7fbb..1ad5f7a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -45,9 +45,14 @@ public class JavaTestParams extends JavaParams { } private IntParam myIntParam_; - public IntParam myIntParam() { return myIntParam_; } - public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); } + public IntParam myIntParam() { + return myIntParam_; + } + + public int getMyIntParam() { + return (Integer) getOrDefault(myIntParam_); + } public JavaTestParams setMyIntParam(int value) { set(myIntParam_, value); @@ -55,9 +60,14 @@ public class JavaTestParams extends JavaParams { } private DoubleParam myDoubleParam_; - public DoubleParam myDoubleParam() { return myDoubleParam_; } - public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); } + public DoubleParam myDoubleParam() { + return myDoubleParam_; + } + + public double getMyDoubleParam() { + return (Double) getOrDefault(myDoubleParam_); + } public JavaTestParams setMyDoubleParam(double value) { set(myDoubleParam_, value); @@ -65,9 +75,14 @@ public class JavaTestParams extends JavaParams { } private Param<String> myStringParam_; - public Param<String> myStringParam() { return myStringParam_; } - public String getMyStringParam() { return getOrDefault(myStringParam_); } + public Param<String> myStringParam() { + return myStringParam_; + } + + public String getMyStringParam() { + return getOrDefault(myStringParam_); + } public JavaTestParams setMyStringParam(String value) { set(myStringParam_, value); @@ -75,9 +90,14 @@ public class JavaTestParams extends 
JavaParams { } private DoubleArrayParam myDoubleArrayParam_; - public DoubleArrayParam myDoubleArrayParam() { return myDoubleArrayParam_; } - public double[] getMyDoubleArrayParam() { return getOrDefault(myDoubleArrayParam_); } + public DoubleArrayParam myDoubleArrayParam() { + return myDoubleArrayParam_; + } + + public double[] getMyDoubleArrayParam() { + return getOrDefault(myDoubleArrayParam_); + } public JavaTestParams setMyDoubleArrayParam(double[] value) { set(myDoubleArrayParam_, value); @@ -96,7 +116,7 @@ public class JavaTestParams extends JavaParams { setDefault(myIntParam(), 1); setDefault(myDoubleParam(), 0.5); - setDefault(myDoubleArrayParam(), new double[] {1.0, 2.0}); + setDefault(myDoubleArrayParam(), new double[]{1.0, 2.0}); } @Override http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java index fa3b28e..bbd59a0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; public class JavaDecisionTreeRegressorSuite implements Serializable { - private transient JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite"); + spark = SparkSession.builder() + .master("local") + 
.appName("JavaDecisionTreeRegressorSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -55,7 +61,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD<LabeledPoint> data = sc.parallelize( + JavaRDD<LabeledPoint> data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); @@ -70,7 +76,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable { .setCacheNodeIds(false) .setCheckpointInterval(10) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String impurity: DecisionTreeRegressor.supportedImpurities()) { + for (String impurity : DecisionTreeRegressor.supportedImpurities()) { dt.setImpurity(impurity); } DecisionTreeRegressionModel model = dt.fit(dataFrame); http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java index 8413ea0..5370b58 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -32,21 +32,27 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; public class JavaGBTRegressorSuite implements Serializable { - private transient 
JavaSparkContext sc; + private transient SparkSession spark; + private transient JavaSparkContext jsc; @Before public void setUp() { - sc = new JavaSparkContext("local", "JavaGBTRegressorSuite"); + spark = SparkSession.builder() + .master("local") + .appName("JavaGBTRegressorSuite") + .getOrCreate(); + jsc = new JavaSparkContext(spark.sparkContext()); } @After public void tearDown() { - sc.stop(); - sc = null; + spark.stop(); + spark = null; } @Test @@ -55,7 +61,7 @@ public class JavaGBTRegressorSuite implements Serializable { double A = 2.0; double B = -1.5; - JavaRDD<LabeledPoint> data = sc.parallelize( + JavaRDD<LabeledPoint> data = jsc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); @@ -73,7 +79,7 @@ public class JavaGBTRegressorSuite implements Serializable { .setMaxIter(3) .setStepSize(0.1) .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern - for (String lossType: GBTRegressor.supportedLossTypes()) { + for (String lossType : GBTRegressor.supportedLossTypes()) { rf.setLossType(lossType); } GBTRegressionModel model = rf.fit(dataFrame); --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
