http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index 8e7e000..125ad02 100755
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Dataset, Row}

 object StopWordsRemoverSuite extends SparkFunSuite {
   def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
@@ -42,7 +42,7 @@ class StopWordsRemoverSuite
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol("filtered")
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("test", "test"), Seq("test", "test")),
       (Seq("a", "b", "c", "d"), Seq("b", "c")),
       (Seq("a", "the", "an"), Seq()),
@@ -60,7 +60,7 @@ class StopWordsRemoverSuite
       .setInputCol("raw")
       .setOutputCol("filtered")
       .setStopWords(stopWords)
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("test", "test"), Seq()),
       (Seq("a", "b", "c", "d"), Seq("b", "c", "d")),
       (Seq("a", "the", "an"), Seq()),
@@ -77,7 +77,7 @@ class StopWordsRemoverSuite
       .setInputCol("raw")
       .setOutputCol("filtered")
       .setCaseSensitive(true)
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("A"), Seq("A")),
       (Seq("The", "the"), Seq("The"))
     )).toDF("raw", "expected")
@@ -98,7 +98,7 @@ class StopWordsRemoverSuite
       .setInputCol("raw")
       .setOutputCol("filtered")
       .setStopWords(stopWords)
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("acaba", "ama", "biri"), Seq()),
       (Seq("hep", "her", "scala"), Seq("scala"))
     )).toDF("raw", "expected")
@@ -112,7 +112,7 @@ class StopWordsRemoverSuite
       .setInputCol("raw")
       .setOutputCol("filtered")
       .setStopWords(stopWords.toArray)
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("python", "scala", "a"), Seq("python", "scala", "a")),
       (Seq("Python", "Scala", "swift"), Seq("Python", "Scala", "swift"))
     )).toDF("raw", "expected")
@@ -126,7 +126,7 @@ class StopWordsRemoverSuite
       .setInputCol("raw")
       .setOutputCol("filtered")
       .setStopWords(stopWords.toArray)
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("python", "scala", "a"), Seq()),
       (Seq("Python", "Scala", "swift"), Seq("swift"))
     )).toDF("raw", "expected")
@@ -148,7 +148,7 @@ class StopWordsRemoverSuite
     val remover = new StopWordsRemover()
       .setInputCol("raw")
       .setOutputCol(outputCol)
-    val dataSet = sqlContext.createDataFrame(Seq(
+    val dataSet = spark.createDataFrame(Seq(
       (Seq("The", "the", "swift"), Seq("swift"))
     )).toDF("raw", outputCol)
http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index d0f3cdc..c221d4a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -39,7 +39,7 @@ class StringIndexerSuite
   test("StringIndexer") {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
-    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val df = spark.createDataFrame(data).toDF("id", "label")
     val indexer = new StringIndexer()
       .setInputCol("label")
       .setOutputCol("labelIndex")
@@ -63,8 +63,8 @@ class StringIndexerSuite
   test("StringIndexerUnseen") {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
     val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
-    val df = sqlContext.createDataFrame(data).toDF("id", "label")
-    val df2 = sqlContext.createDataFrame(data2).toDF("id", "label")
+    val df = spark.createDataFrame(data).toDF("id", "label")
+    val df2 = spark.createDataFrame(data2).toDF("id", "label")
     val indexer = new StringIndexer()
       .setInputCol("label")
       .setOutputCol("labelIndex")
@@ -93,7 +93,7 @@ class StringIndexerSuite
   test("StringIndexer with a numeric input column") {
     val data = sc.parallelize(Seq((0, 100), (1, 200), (2, 300), (3, 100), (4, 100), (5, 300)), 2)
-    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val df = spark.createDataFrame(data).toDF("id", "label")
     val indexer = new StringIndexer()
       .setInputCol("label")
       .setOutputCol("labelIndex")
@@ -114,12 +114,12 @@ class StringIndexerSuite
     val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
       .setInputCol("label")
       .setOutputCol("labelIndex")
-    val df = sqlContext.range(0L, 10L).toDF()
+    val df = spark.range(0L, 10L).toDF()
     assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
   }

   test("StringIndexerModel can't overwrite output column") {
-    val df = sqlContext.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
+    val df = spark.createDataFrame(Seq((1, 2), (3, 4))).toDF("input", "output")
     val indexer = new StringIndexer()
       .setInputCol("input")
       .setOutputCol("output")
@@ -153,7 +153,7 @@ class StringIndexerSuite
   test("IndexToString.transform") {
     val labels = Array("a", "b", "c")
-    val df0 = sqlContext.createDataFrame(Seq(
+    val df0 = spark.createDataFrame(Seq(
       (0, "a"), (1, "b"), (2, "c"), (0, "a")
     )).toDF("index", "expected")
@@ -180,7 +180,7 @@ class StringIndexerSuite
   test("StringIndexer, IndexToString are inverses") {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
-    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val df = spark.createDataFrame(data).toDF("id", "label")
     val indexer = new StringIndexer()
       .setInputCol("label")
       .setOutputCol("labelIndex")
@@ -213,7 +213,7 @@ class StringIndexerSuite
   test("StringIndexer metadata") {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
-    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val df = spark.createDataFrame(data).toDF("id", "label")
     val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
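
The recurring change in the feature-suite hunks above is that test DataFrames are now built from a SparkSession handle (spark.createDataFrame, spark.range) rather than from a SQLContext. For readers following along outside the test harness, here is a minimal, self-contained Scala sketch of that pattern; the master setting, app name, and column names are illustrative and not taken from the diff:

    import org.apache.spark.ml.feature.StringIndexer
    import org.apache.spark.sql.SparkSession

    object StringIndexerSketch {
      def main(args: Array[String]): Unit = {
        // SparkSession replaces the SQLContext handle the old tests relied on.
        val spark = SparkSession.builder
          .master("local[2]")              // illustrative
          .appName("StringIndexerSketch")  // illustrative
          .getOrCreate()

        // createDataFrame on a local Seq of tuples, as in the updated suites.
        val df = spark.createDataFrame(
          Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
        ).toDF("id", "label")

        val indexer = new StringIndexer()
          .setInputCol("label")
          .setOutputCol("labelIndex")

        // Fit the indexer and inspect the indexed labels.
        indexer.fit(df).transform(df).show()

        spark.stop()
      }
    }
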
http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index 123ddfe..f30bdc3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -57,13 +57,13 @@ class RegexTokenizerSuite
       .setPattern("\\w+|\\p{Punct}")
       .setInputCol("rawText")
       .setOutputCol("tokens")
-    val dataset0 = sqlContext.createDataFrame(Seq(
+    val dataset0 = spark.createDataFrame(Seq(
       TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization", ".")),
       TokenizerTestData("Te,st. punct", Array("te", ",", "st", ".", "punct"))
     ))
     testRegexTokenizer(tokenizer0, dataset0)

-    val dataset1 = sqlContext.createDataFrame(Seq(
+    val dataset1 = spark.createDataFrame(Seq(
       TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization")),
       TokenizerTestData("Te,st. punct", Array("punct"))
     ))
@@ -73,7 +73,7 @@ class RegexTokenizerSuite
     val tokenizer2 = new RegexTokenizer()
       .setInputCol("rawText")
       .setOutputCol("tokens")
-    val dataset2 = sqlContext.createDataFrame(Seq(
+    val dataset2 = spark.createDataFrame(Seq(
       TokenizerTestData("Test for tokenization.", Array("test", "for", "tokenization.")),
       TokenizerTestData("Te,st. punct", Array("te,st.", "punct"))
     ))
@@ -85,7 +85,7 @@ class RegexTokenizerSuite
       .setInputCol("rawText")
       .setOutputCol("tokens")
       .setToLowercase(false)
-    val dataset = sqlContext.createDataFrame(Seq(
+    val dataset = spark.createDataFrame(Seq(
       TokenizerTestData("JAVA SCALA", Array("JAVA", "SCALA")),
       TokenizerTestData("java scala", Array("java", "scala"))
     ))

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index dce994f..250011c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -57,7 +57,7 @@ class VectorAssemblerSuite
   }

   test("VectorAssembler") {
-    val df = sqlContext.createDataFrame(Seq(
+    val df = spark.createDataFrame(Seq(
       (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L)
     )).toDF("id", "x", "y", "name", "z", "n")
     val assembler = new VectorAssembler()
@@ -70,7 +70,7 @@ class VectorAssemblerSuite
   }

   test("transform should throw an exception in case of unsupported type") {
-    val df = sqlContext.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
+    val df = spark.createDataFrame(Seq(("a", "b", "c"))).toDF("a", "b", "c")
     val assembler = new VectorAssembler()
       .setInputCols(Array("a", "b", "c"))
       .setOutputCol("features")
@@ -87,7 +87,7 @@ class VectorAssemblerSuite
       NominalAttribute.defaultAttr.withName("gender").withValues("male", "female"),
       NumericAttribute.defaultAttr.withName("salary")))
     val row = (1.0, 0.5, 1, Vectors.dense(1.0, 1000.0), Vectors.sparse(2, Array(1), Array(2.0)))
-    val df = sqlContext.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
+    val df = spark.createDataFrame(Seq(row)).toDF("browser", "hour", "count", "user", "ad")
       .select(
         col("browser").as("browser", browser.toMetadata()),
         col("hour").as("hour", hour.toMetadata()),

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 1ffc62b..d1c0270 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -85,11 +85,11 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
     checkPair(densePoints1Seq, sparsePoints1Seq)
     checkPair(densePoints2Seq, sparsePoints2Seq)

-    densePoints1 = sqlContext.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
-    sparsePoints1 = sqlContext.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
-    densePoints2 = sqlContext.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
-    sparsePoints2 = sqlContext.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
-    badPoints = sqlContext.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
+    densePoints1 = spark.createDataFrame(sc.parallelize(densePoints1Seq, 2).map(FeatureData))
+    sparsePoints1 = spark.createDataFrame(sc.parallelize(sparsePoints1Seq, 2).map(FeatureData))
+    densePoints2 = spark.createDataFrame(sc.parallelize(densePoints2Seq, 2).map(FeatureData))
+    sparsePoints2 = spark.createDataFrame(sc.parallelize(sparsePoints2Seq, 2).map(FeatureData))
+    badPoints = spark.createDataFrame(sc.parallelize(badPointsSeq, 2).map(FeatureData))
   }

   private def getIndexer: VectorIndexer =
@@ -102,7 +102,7 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
   }

   test("Cannot fit an empty DataFrame") {
-    val rdd = sqlContext.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
+    val rdd = spark.createDataFrame(sc.parallelize(Array.empty[Vector], 2).map(FeatureData))
     val vectorIndexer = getIndexer
     intercept[IllegalArgumentException] {
       vectorIndexer.fit(rdd)

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 6bb4678..88a077f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -79,7 +79,7 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with De
     val resultAttrGroup = new AttributeGroup("expected", resultAttrs.asInstanceOf[Array[Attribute]])

     val rdd = sc.parallelize(data.zip(expected)).map { case (a, b) => Row(a, b) }
-    val df = sqlContext.createDataFrame(rdd,
+    val df = spark.createDataFrame(rdd,
       StructType(Array(attrGroup.toStructField(), resultAttrGroup.toStructField())))

     val vectorSlicer = new VectorSlicer().setInputCol("features").setOutputCol("result")

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 80c177b..8cbe0f3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -36,8 +36,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul

   test("Word2Vec") {
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._

     val sentence = "a b " * 100 + "a c " * 10
     val numOfWords = sentence.split(" ").size
@@ -78,8 +78,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul

   test("getVectors") {
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._

     val sentence = "a b " * 100 + "a c " * 10
     val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@@ -119,8 +119,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul

   test("findSynonyms") {
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._

     val sentence = "a b " * 100 + "a c " * 10
     val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
@@ -146,8 +146,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul

   test("window size") {
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._

     val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
     val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 1704037..9da0c32 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils

@@ -305,8 +305,8 @@ class ALSSuite
       numUserBlocks: Int = 2,
       numItemBlocks: Int = 3,
       targetRMSE: Double = 0.05): Unit = {
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._
     val als = new ALS()
       .setRank(rank)
       .setRegParam(regParam)
@@ -460,8 +460,8 @@ class ALSSuite
     allEstimatorParamSettings.foreach { case (p, v) =>
       als.set(als.getParam(p), v)
     }
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._
     val model = als.fit(ratings.toDF())

     // Test Estimator save/load
@@ -535,8 +535,11 @@ class ALSCleanerSuite extends SparkFunSuite {
       // Generate test data
       val (training, _) = ALSSuite.genImplicitTestData(sc, 20, 5, 1, 0.2, 0)
       // Implicitly test the cleaning of parents during ALS training
-      val sqlContext = new SQLContext(sc)
-      import sqlContext.implicits._
+      val spark = SparkSession.builder
+        .master("local[2]")
+        .appName("ALSCleanerSuite")
+        .getOrCreate()
+      import spark.implicits._
       val als = new ALS()
         .setRank(1)
         .setRegParam(1e-5)
@@ -577,8 +580,8 @@ class ALSStorageSuite
   }

   test("default and non-default storage params set correct RDD StorageLevels") {
-    val sqlContext = this.sqlContext
-    import sqlContext.implicits._
+    val spark = this.spark
+    import spark.implicits._
     val data = Seq(
       (0, 0, 1.0),
       (0, 1, 2.0),

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 76891ad..f8fc775 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -37,13 +37,13 @@ class AFTSurvivalRegressionSuite

   override def beforeAll(): Unit = {
     super.beforeAll()
-    datasetUnivariate = sqlContext.createDataFrame(
+    datasetUnivariate = spark.createDataFrame(
       sc.parallelize(generateAFTInput(
         1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)))
-    datasetMultivariate = sqlContext.createDataFrame(
+    datasetMultivariate = spark.createDataFrame(
       sc.parallelize(generateAFTInput(
         2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0)))
-    datasetUnivariateScaled = sqlContext.createDataFrame(
+    datasetUnivariateScaled = spark.createDataFrame(
       sc.parallelize(generateAFTInput(
         1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x =>
           AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor)
@@ -356,7 +356,7 @@ class AFTSurvivalRegressionSuite
   test("should support all NumericType labels") {
     val aft = new AFTSurvivalRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
-      aft, isClassification = false, sqlContext) { (expected, actual) =>
+      aft, isClassification = false, spark) { (expected, actual) =>
        assert(expected.intercept === actual.intercept)
        assert(expected.coefficients === actual.coefficients)
      }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index e9fb267..d9f26ad 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -120,7 +120,7 @@ class DecisionTreeRegressorSuite
   test("should support all NumericType labels and not support other types") {
     val dt = new DecisionTreeRegressor().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](
-      dt, isClassification = false, sqlContext) { (expected, actual) =>
+      dt, isClassification = false, spark) { (expected, actual) =>
         TreeTests.checkEqual(expected, actual)
       }
   }
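
Two related changes run through the ALS hunks above: the session is now obtained with SparkSession.builder ... getOrCreate() (which reuses an already-running SparkContext, as in ALSCleanerSuite), and the toDF/toDS implicits are imported from the session via import spark.implicits._ instead of from a SQLContext. A small sketch of both pieces, standing alone; the master, app name, and column names are illustrative:

    import org.apache.spark.sql.SparkSession

    object ImplicitsSketch {
      def main(args: Array[String]): Unit = {
        // getOrCreate() returns the active session if one already exists,
        // otherwise it builds a new one; that is why test code can call it
        // after a SparkContext has been started elsewhere.
        val spark = SparkSession.builder
          .master("local[2]")          // illustrative
          .appName("ImplicitsSketch")  // illustrative
          .getOrCreate()

        // Implicit conversions (toDF, toDS, ...) now come from the session.
        import spark.implicits._

        val ratings = Seq((0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0))
          .toDF("user", "item", "rating")  // illustrative column names
        ratings.show()

        spark.stop()
      }
    }
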
http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 2163779..f6ea5bb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -72,7 +72,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
   }

   test("GBTRegressor behaves reasonably on toy data") {
-    val df = sqlContext.createDataFrame(Seq(
+    val df = spark.createDataFrame(Seq(
       LabeledPoint(10, Vectors.dense(1, 2, 3, 4)),
       LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)),
       LabeledPoint(11, Vectors.dense(2, 2, 3, 4)),
@@ -99,7 +99,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
     val path = tempDir.toURI.toString
     sc.setCheckpointDir(path)

-    val df = sqlContext.createDataFrame(data)
+    val df = spark.createDataFrame(data)
     val gbt = new GBTRegressor()
       .setMaxDepth(2)
       .setMaxIter(5)
@@ -115,7 +115,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
   test("should support all NumericType labels and not support other types") {
     val gbt = new GBTRegressor().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor](
-      gbt, isClassification = false, sqlContext) { (expected, actual) =>
+      gbt, isClassification = false, spark) { (expected, actual) =>
         TreeTests.checkEqual(expected, actual)
       }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index b854be2..161f8c8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -52,19 +52,19 @@ class GeneralizedLinearRegressionSuite

     import GeneralizedLinearRegressionSuite._

-    datasetGaussianIdentity = sqlContext.createDataFrame(
+    datasetGaussianIdentity = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "gaussian", link = "identity"), 2))

-    datasetGaussianLog = sqlContext.createDataFrame(
+    datasetGaussianLog = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "gaussian", link = "log"), 2))

-    datasetGaussianInverse = sqlContext.createDataFrame(
+    datasetGaussianInverse = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
@@ -80,40 +80,40 @@ class GeneralizedLinearRegressionSuite
         generateMultinomialLogisticInput(coefficients, xMean, xVariance,
           addIntercept = true, nPoints, seed)

-      sqlContext.createDataFrame(sc.parallelize(testData, 2))
+      spark.createDataFrame(sc.parallelize(testData, 2))
     }

-    datasetPoissonLog = sqlContext.createDataFrame(
+    datasetPoissonLog = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "poisson", link = "log"), 2))

-    datasetPoissonIdentity = sqlContext.createDataFrame(
+    datasetPoissonIdentity = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "poisson", link = "identity"), 2))

-    datasetPoissonSqrt = sqlContext.createDataFrame(
+    datasetPoissonSqrt = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "poisson", link = "sqrt"), 2))

-    datasetGammaInverse = sqlContext.createDataFrame(
+    datasetGammaInverse = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "gamma", link = "inverse"), 2))

-    datasetGammaIdentity = sqlContext.createDataFrame(
+    datasetGammaIdentity = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 2.5, coefficients = Array(2.2, 0.6), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
         family = "gamma", link = "identity"), 2))

-    datasetGammaLog = sqlContext.createDataFrame(
+    datasetGammaLog = spark.createDataFrame(
       sc.parallelize(generateGeneralizedLinearRegressionInput(
         intercept = 0.25, coefficients = Array(0.22, 0.06), xMean = Array(2.9, 10.5),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, noiseLevel = 0.01,
@@ -540,7 +540,7 @@ class GeneralizedLinearRegressionSuite
       w <- c(1, 2, 3, 4)
       df <- as.data.frame(cbind(A, b))
     */
-    val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+    val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq(
       Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
       Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
       Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)),
@@ -668,7 +668,7 @@ class GeneralizedLinearRegressionSuite
       w <- c(1, 2, 3, 4)
       df <- as.data.frame(cbind(A, b))
     */
-    val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+    val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq(
       Instance(1.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
       Instance(0.0, 2.0, Vectors.dense(1.0, 2.0)),
       Instance(1.0, 3.0, Vectors.dense(2.0, 1.0)),
@@ -782,7 +782,7 @@ class GeneralizedLinearRegressionSuite
       w <- c(1, 2, 3, 4)
       df <- as.data.frame(cbind(A, b))
     */
-    val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+    val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq(
       Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
       Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
       Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
@@ -899,7 +899,7 @@ class GeneralizedLinearRegressionSuite
       w <- c(1, 2, 3, 4)
       df <- as.data.frame(cbind(A, b))
     */
-    val datasetWithWeight = sqlContext.createDataFrame(sc.parallelize(Seq(
+    val datasetWithWeight = spark.createDataFrame(sc.parallelize(Seq(
       Instance(2.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
       Instance(8.0, 2.0, Vectors.dense(1.0, 7.0)),
       Instance(3.0, 3.0, Vectors.dense(2.0, 11.0)),
@@ -1021,14 +1021,14 @@ class GeneralizedLinearRegressionSuite
     val glr = new GeneralizedLinearRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[
         GeneralizedLinearRegressionModel, GeneralizedLinearRegression](
-      glr, isClassification = false, sqlContext) { (expected, actual) =>
+      glr, isClassification = false, spark) { (expected, actual) =>
         assert(expected.intercept === actual.intercept)
         assert(expected.coefficients === actual.coefficients)
       }
   }

   test("glm accepts Dataset[LabeledPoint]") {
-    val context = sqlContext
+    val context = spark
     import context.implicits._
     new GeneralizedLinearRegression()
       .setFamily("gaussian")

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 3a10ad7..9bf7542 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -28,13 +28,13 @@ class IsotonicRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

   private def generateIsotonicInput(labels: Seq[Double]): DataFrame = {
-    sqlContext.createDataFrame(
+    spark.createDataFrame(
       labels.zipWithIndex.map { case (label, i) => (label, i.toDouble, 1.0) }
     ).toDF("label", "features", "weight")
   }

   private def generatePredictionInput(features: Seq[Double]): DataFrame = {
-    sqlContext.createDataFrame(features.map(Tuple1.apply))
+    spark.createDataFrame(features.map(Tuple1.apply))
       .toDF("features")
   }

@@ -145,7 +145,7 @@ class IsotonicRegressionSuite
   }

   test("vector features column with feature index") {
-    val dataset = sqlContext.createDataFrame(Seq(
+    val dataset = spark.createDataFrame(Seq(
       (4.0, Vectors.dense(0.0, 1.0)),
       (3.0, Vectors.dense(0.0, 2.0)),
       (5.0, Vectors.sparse(2, Array(1), Array(3.0))))
@@ -184,7 +184,7 @@ class IsotonicRegressionSuite
   test("should support all NumericType labels and not support other types") {
     val ir = new IsotonicRegression()
     MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression](
-      ir, isClassification = false, sqlContext) { (expected, actual) =>
+      ir, isClassification = false, spark) { (expected, actual) =>
         assert(expected.boundaries === actual.boundaries)
         assert(expected.predictions === actual.predictions)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index eb19d13..10f547b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -42,7 +42,7 @@ class LinearRegressionSuite

   override def beforeAll(): Unit = {
     super.beforeAll()
-    datasetWithDenseFeature = sqlContext.createDataFrame(
+    datasetWithDenseFeature = spark.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2))
@@ -50,7 +50,7 @@ class LinearRegressionSuite
        datasetWithDenseFeatureWithoutIntercept is not needed for correctness testing
        but is useful for illustrating training model without intercept
      */
-    datasetWithDenseFeatureWithoutIntercept = sqlContext.createDataFrame(
+    datasetWithDenseFeatureWithoutIntercept = spark.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         intercept = 0.0, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3),
         xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2))
@@ -59,7 +59,7 @@ class LinearRegressionSuite
     // When feature size is larger than 4096, normal optimizer is choosed
     // as the solver of linear regression in the case of "auto" mode.
     val featureSize = 4100
-    datasetWithSparseFeature = sqlContext.createDataFrame(
+    datasetWithSparseFeature = spark.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray,
         xMean = Seq.fill(featureSize)(r.nextDouble()).toArray,
@@ -74,7 +74,7 @@ class LinearRegressionSuite
        w <- c(1, 2, 3, 4)
        df <- as.data.frame(cbind(A, b))
      */
-    datasetWithWeight = sqlContext.createDataFrame(
+    datasetWithWeight = spark.createDataFrame(
       sc.parallelize(Seq(
         Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
         Instance(19.0, 2.0, Vectors.dense(1.0, 7.0)),
@@ -90,14 +90,14 @@ class LinearRegressionSuite
        w <- c(1, 2, 3, 4)
        df.const.label <- as.data.frame(cbind(A, b.const))
      */
-    datasetWithWeightConstantLabel = sqlContext.createDataFrame(
+    datasetWithWeightConstantLabel = spark.createDataFrame(
       sc.parallelize(Seq(
         Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
         Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)),
         Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
         Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
       ), 2))
-    datasetWithWeightZeroLabel = sqlContext.createDataFrame(
+    datasetWithWeightZeroLabel = spark.createDataFrame(
       sc.parallelize(Seq(
         Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
         Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)),
@@ -828,8 +828,8 @@ class LinearRegressionSuite
       }
       val data2 = weightedSignedData ++ weightedNoiseData

-      (sqlContext.createDataFrame(sc.parallelize(data1, 4)),
-        sqlContext.createDataFrame(sc.parallelize(data2, 4)))
+      (spark.createDataFrame(sc.parallelize(data1, 4)),
+        spark.createDataFrame(sc.parallelize(data2, 4)))
     }

     val trainer1a = (new LinearRegression).setFitIntercept(true)
@@ -1010,7 +1010,7 @@ class LinearRegressionSuite
   test("should support all NumericType labels and not support other types") {
     val lr = new LinearRegression().setMaxIter(1)
     MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
-      lr, isClassification = false, sqlContext) { (expected, actual) =>
+      lr, isClassification = false, spark) { (expected, actual) =>
         assert(expected.intercept === actual.intercept)
         assert(expected.coefficients === actual.coefficients)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index ca400e1..72f3c65 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -98,7 +98,7 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
   test("should support all NumericType labels and not support other types") {
     val rf = new RandomForestRegressor().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor](
-      rf, isClassification = false, sqlContext) { (expected, actual) =>
+      rf, isClassification = false, spark) { (expected, actual) =>
         TreeTests.checkEqual(expected, actual)
       }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index 1d7144f..7d0e01f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -56,7 +56,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
   }

   test("select as sparse vector") {
-    val df = sqlContext.read.format("libsvm").load(path)
+    val df = spark.read.format("libsvm").load(path)
     assert(df.columns(0) == "label")
     assert(df.columns(1) == "features")
     val row1 = df.first()
@@ -66,7 +66,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
   }

   test("select as dense vector") {
-    val df = sqlContext.read.format("libsvm").options(Map("vectorType" -> "dense"))
+    val df = spark.read.format("libsvm").options(Map("vectorType" -> "dense"))
       .load(path)
     assert(df.columns(0) == "label")
     assert(df.columns(1) == "features")
@@ -78,7 +78,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
   }

   test("select a vector with specifying the longer dimension") {
-    val df = sqlContext.read.option("numFeatures", "100").format("libsvm")
+    val df = spark.read.option("numFeatures", "100").format("libsvm")
       .load(path)
     val row1 = df.first()
     val v = row1.getAs[SparseVector](1)
@@ -86,27 +86,27 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
   }

   test("write libsvm data and read it again") {
-    val df = sqlContext.read.format("libsvm").load(path)
+    val df = spark.read.format("libsvm").load(path)
     val tempDir2 = new File(tempDir, "read_write_test")
     val writepath = tempDir2.toURI.toString
     // TODO: Remove requirement to coalesce by supporting multiple reads.
     df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath)

-    val df2 = sqlContext.read.format("libsvm").load(writepath)
+    val df2 = spark.read.format("libsvm").load(writepath)
     val row1 = df2.first()
     val v = row1.getAs[SparseVector](1)
     assert(v == Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
   }

   test("write libsvm data failed due to invalid schema") {
-    val df = sqlContext.read.format("text").load(path)
+    val df = spark.read.format("text").load(path)
     intercept[SparkException] {
       df.write.format("libsvm").save(path + "_2")
     }
   }

   test("select features from libsvm relation") {
-    val df = sqlContext.read.format("libsvm").load(path)
+    val df = spark.read.format("libsvm").load(path)
     df.select("features").rdd.map { case Row(d: Vector) => d }.first
     df.select("features").collect
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
index fecf372..de92b51 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala
@@ -37,8 +37,8 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
     val numIterations = 20
     val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2)
     val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2)
-    val trainDF = sqlContext.createDataFrame(trainRdd)
-    val validateDF = sqlContext.createDataFrame(validateRdd)
+    val trainDF = spark.createDataFrame(trainRdd)
+    val validateDF = spark.createDataFrame(validateRdd)

     val algos = Array(Regression, Regression, Classification)
     val losses = Array(SquaredError, AbsoluteError, LogLoss)

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index e3f0989..12ade4c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.tree._
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession}

 private[ml] object TreeTests extends SparkFunSuite {

@@ -42,8 +42,12 @@ private[ml] object TreeTests extends SparkFunSuite {
       data: RDD[LabeledPoint],
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): DataFrame = {
-    val sqlContext = SQLContext.getOrCreate(data.sparkContext)
-    import sqlContext.implicits._
+    val spark = SparkSession.builder
+      .master("local[2]")
+      .appName("TreeTests")
+      .getOrCreate()
+    import spark.implicits._
+
     val df = data.toDF()
     val numFeatures = data.first().features.size
     val featuresAttributes = Range(0, numFeatures).map { feature =>

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 061d04c..85df6da 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -39,7 +39,7 @@ class CrossValidatorSuite

   override def beforeAll(): Unit = {
     super.beforeAll()
-    dataset = sqlContext.createDataFrame(
+    dataset = spark.createDataFrame(
       sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
   }

@@ -67,7 +67,7 @@ class CrossValidatorSuite
   }

   test("cross validation with linear regression") {
-    val dataset = sqlContext.createDataFrame(
+    val dataset = spark.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index df9ba41..f8d3de1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types.StructType
 class TrainValidationSplitSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
   test("train validation with logistic regression") {
-    val dataset = sqlContext.createDataFrame(
+    val dataset = spark.createDataFrame(
       sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))

     val lr = new LogisticRegression
@@ -58,7 +58,7 @@ class TrainValidationSplitSuite
   }

   test("train validation with linear regression") {
-    val dataset = sqlContext.createDataFrame(
+    val dataset = spark.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index d9e6fd5a..4fe473b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._

@@ -38,17 +38,17 @@ object MLTestingUtils extends SparkFunSuite {
   def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
       estimator: T,
       isClassification: Boolean,
-      sqlContext: SQLContext)(check: (M, M) => Unit): Unit = {
+      spark: SparkSession)(check: (M, M) => Unit): Unit = {
     val dfs = if (isClassification) {
-      genClassifDFWithNumericLabelCol(sqlContext)
+      genClassifDFWithNumericLabelCol(spark)
     } else {
-      genRegressionDFWithNumericLabelCol(sqlContext)
+      genRegressionDFWithNumericLabelCol(spark)
     }
     val expected = estimator.fit(dfs(DoubleType))
     val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t)))
     actuals.foreach(actual => check(expected, actual))

-    val dfWithStringLabels = sqlContext.createDataFrame(Seq(
+    val dfWithStringLabels = spark.createDataFrame(Seq(
       ("0", Vectors.dense(0, 2, 3), 0.0)
     )).toDF("label", "features", "censor")
     val thrown = intercept[IllegalArgumentException] {
@@ -58,13 +58,13 @@ object MLTestingUtils extends SparkFunSuite {
       "Column label must be of type NumericType but was actually of type StringType"))
   }

-  def checkNumericTypes[T <: Evaluator](evaluator: T, sqlContext: SQLContext): Unit = {
-    val dfs = genEvaluatorDFWithNumericLabelCol(sqlContext, "label", "prediction")
+  def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = {
+    val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction")
     val expected = evaluator.evaluate(dfs(DoubleType))
     val actuals = dfs.keys.filter(_ != DoubleType).map(t => evaluator.evaluate(dfs(t)))
     actuals.foreach(actual => assert(expected === actual))

-    val dfWithStringLabels = sqlContext.createDataFrame(Seq(
+    val dfWithStringLabels = spark.createDataFrame(Seq(
       ("0", 0d)
     )).toDF("label", "prediction")
     val thrown = intercept[IllegalArgumentException] {
@@ -75,10 +75,10 @@ object MLTestingUtils extends SparkFunSuite {
   }

   def genClassifDFWithNumericLabelCol(
-      sqlContext: SQLContext,
+      spark: SparkSession,
       labelColName: String = "label",
       featuresColName: String = "features"): Map[NumericType, DataFrame] = {
-    val df = sqlContext.createDataFrame(Seq(
+    val df = spark.createDataFrame(Seq(
       (0, Vectors.dense(0, 2, 3)),
       (1, Vectors.dense(0, 3, 1)),
       (0, Vectors.dense(0, 2, 2)),
@@ -95,11 +95,11 @@ object MLTestingUtils extends SparkFunSuite {
   }

   def genRegressionDFWithNumericLabelCol(
-      sqlContext: SQLContext,
+      spark: SparkSession,
       labelColName: String = "label",
       featuresColName: String = "features",
       censorColName: String = "censor"): Map[NumericType, DataFrame] = {
-    val df = sqlContext.createDataFrame(Seq(
+    val df = spark.createDataFrame(Seq(
       (0, Vectors.dense(0)),
       (1, Vectors.dense(1)),
       (2, Vectors.dense(2)),
@@ -117,10 +117,10 @@ object MLTestingUtils extends SparkFunSuite {
   }

   def genEvaluatorDFWithNumericLabelCol(
-      sqlContext: SQLContext,
+      spark: SparkSession,
       labelColName: String = "label",
       predictionColName: String = "prediction"): Map[NumericType, DataFrame] = {
-    val df = sqlContext.createDataFrame(Seq(
+    val df = spark.createDataFrame(Seq(
       (0, 0d),
       (1, 1d),
       (2, 2d),

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
index 7f9e340..ba8d36f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala
@@ -23,23 +23,22 @@ import org.scalatest.Suite

 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.ml.util.TempDirectory
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{SparkSession, SQLContext}
 import org.apache.spark.util.Utils

 trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
+  @transient var spark: SparkSession = _
   @transient var sc: SparkContext = _
-  @transient var sqlContext: SQLContext = _
   @transient var checkpointDir: String = _

   override def beforeAll() {
     super.beforeAll()
-    val conf = new SparkConf()
-      .setMaster("local[2]")
-      .setAppName("MLlibUnitTest")
-    sc = new SparkContext(conf)
-    SQLContext.clearActive()
-    sqlContext = new SQLContext(sc)
-    SQLContext.setActive(sqlContext)
+    spark = SparkSession.builder
+      .master("local[2]")
+      .appName("MLlibUnitTest")
+      .getOrCreate()
+    sc = spark.sparkContext
+
     checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
     sc.setCheckpointDir(checkpointDir)
   }
@@ -47,12 +46,11 @@ trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
   override def afterAll() {
     try {
       Utils.deleteRecursively(new File(checkpointDir))
-      sqlContext = null
       SQLContext.clearActive()
-      if (sc != null) {
-        sc.stop()
+      if (spark != null) {
+        spark.stop()
       }
-      sc = null
+      spark = null
     } finally {
       super.afterAll()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 189cc39..f2ae40e 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -28,14 +28,13 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;

-import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
@@ -44,21 +43,22 @@ import org.apache.spark.sql.types.StructType;
 // serialized, as an alternative to converting these anonymous classes to static inner classes;
 // see http://stackoverflow.com/questions/758570/.
 public class JavaApplySchemaSuite implements Serializable {
-  private transient JavaSparkContext javaCtx;
-  private transient SQLContext sqlContext;
+  private transient SparkSession spark;
+  private transient JavaSparkContext jsc;

   @Before
   public void setUp() {
-    SparkContext context = new SparkContext("local[*]", "testing");
-    javaCtx = new JavaSparkContext(context);
-    sqlContext = new SQLContext(context);
+    spark = SparkSession.builder()
+      .master("local[*]")
+      .appName("testing")
+      .getOrCreate();
+    jsc = new JavaSparkContext(spark.sparkContext());
   }

   @After
   public void tearDown() {
-    sqlContext.sparkContext().stop();
-    sqlContext = null;
-    javaCtx = null;
+    spark.stop();
+    spark = null;
   }

   public static class Person implements Serializable {
@@ -94,7 +94,7 @@ public class JavaApplySchemaSuite implements Serializable {
     person2.setAge(28);
     personList.add(person2);

-    JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map(
+    JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
       new Function<Person, Row>() {
         @Override
         public Row call(Person person) throws Exception {
@@ -107,9 +107,9 @@ public class JavaApplySchemaSuite implements Serializable {
     fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
     StructType schema = DataTypes.createStructType(fields);

-    Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
+    Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
     df.registerTempTable("people");
-    List<Row> actual = sqlContext.sql("SELECT * FROM people").collectAsList();
+    List<Row> actual = spark.sql("SELECT * FROM people").collectAsList();

     List<Row> expected = new ArrayList<>(2);
     expected.add(RowFactory.create("Michael", 29));
@@ -130,7 +130,7 @@ public class JavaApplySchemaSuite implements Serializable {
     person2.setAge(28);
     personList.add(person2);

-    JavaRDD<Row> rowRDD = javaCtx.parallelize(personList).map(
+    JavaRDD<Row> rowRDD = jsc.parallelize(personList).map(
       new Function<Person, Row>() {
         @Override
         public Row call(Person person) {
@@ -143,9 +143,9 @@ public class JavaApplySchemaSuite implements Serializable {
     fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
     StructType schema = DataTypes.createStructType(fields);

-    Dataset<Row> df = sqlContext.createDataFrame(rowRDD, schema);
+    Dataset<Row> df = spark.createDataFrame(rowRDD, schema);
     df.registerTempTable("people");
-    List<String> actual = sqlContext.sql("SELECT * FROM people").toJavaRDD()
+    List<String> actual = spark.sql("SELECT * FROM people").toJavaRDD()
       .map(new Function<Row, String>() {
         @Override
         public String call(Row row) {
@@ -162,7 +162,7 @@ public class JavaApplySchemaSuite implements Serializable {

   @Test
   public void applySchemaToJSON() {
-    JavaRDD<String> jsonRDD = javaCtx.parallelize(Arrays.asList(
+    JavaRDD<String> jsonRDD = jsc.parallelize(Arrays.asList(
       "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " +
         "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " +
         "\"boolean\":true, \"null\":null}",
@@ -199,18 +199,18 @@ public class JavaApplySchemaSuite implements Serializable {
       null,
       "this is another simple string."));

-    Dataset<Row> df1 = sqlContext.read().json(jsonRDD);
+    Dataset<Row> df1 = spark.read().json(jsonRDD);
     StructType actualSchema1 = df1.schema();
     Assert.assertEquals(expectedSchema, actualSchema1);
     df1.registerTempTable("jsonTable1");
-    List<Row> actual1 = sqlContext.sql("select * from jsonTable1").collectAsList();
+    List<Row> actual1 = spark.sql("select * from jsonTable1").collectAsList();
     Assert.assertEquals(expectedResult, actual1);

-    Dataset<Row> df2 = sqlContext.read().schema(expectedSchema).json(jsonRDD);
+    Dataset<Row> df2 = spark.read().schema(expectedSchema).json(jsonRDD);
     StructType actualSchema2 = df2.schema();
     Assert.assertEquals(expectedSchema, actualSchema2);
     df2.registerTempTable("jsonTable2");
-    List<Row> actual2 = sqlContext.sql("select * from jsonTable2").collectAsList();
+    List<Row> actual2 = spark.sql("select * from jsonTable2").collectAsList();
     Assert.assertEquals(expectedResult, actual2);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 1eb680d..324ebba 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -20,12 +20,7 @@ package test.org.apache.spark.sql;
 import java.io.Serializable;
 import java.net.URISyntaxException;
 import java.net.URL;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.List;
-import java.util.Map;
-import java.util.ArrayList;
+import java.util.*;

 import scala.collection.JavaConverters;
 import scala.collection.Seq;
@@ -34,46 +29,45 @@ import com.google.common.collect.ImmutableMap;
 import com.google.common.primitives.Ints;
 import org.junit.*;

-import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.*;
-import org.apache.spark.sql.test.TestSQLContext;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.test.TestSparkSession;
 import org.apache.spark.sql.types.*;
+import org.apache.spark.util.sketch.BloomFilter;
 import org.apache.spark.util.sketch.CountMinSketch;
 import static org.apache.spark.sql.functions.*;
 import static org.apache.spark.sql.types.DataTypes.*;
-import org.apache.spark.util.sketch.BloomFilter;

 public class JavaDataFrameSuite {
+  private transient TestSparkSession spark;
   private transient JavaSparkContext jsc;
-  private transient TestSQLContext context;

   @Before
   public void setUp() {
     // Trigger static initializer of TestData
-    SparkContext sc = new SparkContext("local[*]", "testing");
-    jsc = new JavaSparkContext(sc);
-    context = new TestSQLContext(sc);
-    context.loadTestData();
+    spark = new TestSparkSession();
+    jsc = new JavaSparkContext(spark.sparkContext());
+    spark.loadTestData();
   }

   @After
   public void tearDown() {
-    context.sparkContext().stop();
-    context = null;
-    jsc = null;
+    spark.stop();
+    spark = null;
   }

   @Test
   public void testExecution() {
-    Dataset<Row> df = context.table("testData").filter("key = 1");
+    Dataset<Row> df = spark.table("testData").filter("key = 1");
     Assert.assertEquals(1, df.select("key").collectAsList().get(0).get(0));
   }

   @Test
   public void testCollectAndTake() {
-    Dataset<Row> df = context.table("testData").filter("key = 1 or key = 2 or key = 3");
+    Dataset<Row> df = spark.table("testData").filter("key = 1 or key = 2 or key = 3");
     Assert.assertEquals(3, df.select("key").collectAsList().size());
     Assert.assertEquals(2, df.select("key").takeAsList(2).size());
   }
@@ -83,7 +77,7 @@ public class JavaDataFrameSuite {
    */
   @Test
public void testVarargMethods() { - Dataset<Row> df = context.table("testData"); + Dataset<Row> df = spark.table("testData"); df.toDF("key1", "value1"); @@ -112,7 +106,7 @@ public class JavaDataFrameSuite { df.select(coalesce(col("key"))); // Varargs with mathfunctions - Dataset<Row> df2 = context.table("testData2"); + Dataset<Row> df2 = spark.table("testData2"); df2.select(exp("a"), exp("b")); df2.select(exp(log("a"))); df2.select(pow("a", "a"), pow("b", 2.0)); @@ -126,7 +120,7 @@ public class JavaDataFrameSuite { @Ignore public void testShow() { // This test case is intended ignored, but to make sure it compiles correctly - Dataset<Row> df = context.table("testData"); + Dataset<Row> df = spark.table("testData"); df.show(); df.show(1000); } @@ -194,7 +188,7 @@ public class JavaDataFrameSuite { public void testCreateDataFrameFromLocalJavaBeans() { Bean bean = new Bean(); List<Bean> data = Arrays.asList(bean); - Dataset<Row> df = context.createDataFrame(data, Bean.class); + Dataset<Row> df = spark.createDataFrame(data, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -202,7 +196,7 @@ public class JavaDataFrameSuite { public void testCreateDataFrameFromJavaBeans() { Bean bean = new Bean(); JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean)); - Dataset<Row> df = context.createDataFrame(rdd, Bean.class); + Dataset<Row> df = spark.createDataFrame(rdd, Bean.class); validateDataFrameWithBeans(bean, df); } @@ -210,7 +204,7 @@ public class JavaDataFrameSuite { public void testCreateDataFromFromList() { StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true))); List<Row> rows = Arrays.asList(RowFactory.create(0)); - Dataset<Row> df = context.createDataFrame(rows, schema); + Dataset<Row> df = spark.createDataFrame(rows, schema); List<Row> result = df.collectAsList(); Assert.assertEquals(1, result.size()); } @@ -239,7 +233,7 @@ public class JavaDataFrameSuite { @Test public void testCrosstab() { - Dataset<Row> df = context.table("testData2"); + Dataset<Row> df = spark.table("testData2"); Dataset<Row> crosstab = df.stat().crosstab("a", "b"); String[] columnNames = crosstab.schema().fieldNames(); Assert.assertEquals("a_b", columnNames[0]); @@ -258,7 +252,7 @@ public class JavaDataFrameSuite { @Test public void testFrequentItems() { - Dataset<Row> df = context.table("testData2"); + Dataset<Row> df = spark.table("testData2"); String[] cols = {"a"}; Dataset<Row> results = df.stat().freqItems(cols, 0.2); Assert.assertTrue(results.collectAsList().get(0).getSeq(0).contains(1)); @@ -266,21 +260,21 @@ public class JavaDataFrameSuite { @Test public void testCorrelation() { - Dataset<Row> df = context.table("testData2"); + Dataset<Row> df = spark.table("testData2"); Double pearsonCorr = df.stat().corr("a", "b", "pearson"); Assert.assertTrue(Math.abs(pearsonCorr) < 1.0e-6); } @Test public void testCovariance() { - Dataset<Row> df = context.table("testData2"); + Dataset<Row> df = spark.table("testData2"); Double result = df.stat().cov("a", "b"); Assert.assertTrue(Math.abs(result) < 1.0e-6); } @Test public void testSampleBy() { - Dataset<Row> df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); + Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key")); Dataset<Row> sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L); List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList(); Assert.assertEquals(0, actual.get(0).getLong(0)); @@ -291,7 +285,7 @@ public class JavaDataFrameSuite { 
@Test public void pivot() { - Dataset<Row> df = context.table("courseSales"); + Dataset<Row> df = spark.table("courseSales"); List<Row> actual = df.groupBy("year") .pivot("course", Arrays.<Object>asList("dotNET", "Java")) .agg(sum("earnings")).orderBy("year").collectAsList(); @@ -324,10 +318,10 @@ public class JavaDataFrameSuite { @Test public void testGenericLoad() { - Dataset<Row> df1 = context.read().format("text").load(getResource("text-suite.txt")); + Dataset<Row> df1 = spark.read().format("text").load(getResource("text-suite.txt")); Assert.assertEquals(4L, df1.count()); - Dataset<Row> df2 = context.read().format("text").load( + Dataset<Row> df2 = spark.read().format("text").load( getResource("text-suite.txt"), getResource("text-suite2.txt")); Assert.assertEquals(5L, df2.count()); @@ -335,10 +329,10 @@ public class JavaDataFrameSuite { @Test public void testTextLoad() { - Dataset<String> ds1 = context.read().text(getResource("text-suite.txt")); + Dataset<String> ds1 = spark.read().text(getResource("text-suite.txt")); Assert.assertEquals(4L, ds1.count()); - Dataset<String> ds2 = context.read().text( + Dataset<String> ds2 = spark.read().text( getResource("text-suite.txt"), getResource("text-suite2.txt")); Assert.assertEquals(5L, ds2.count()); @@ -346,7 +340,7 @@ public class JavaDataFrameSuite { @Test public void testCountMinSketch() { - Dataset<Long> df = context.range(1000); + Dataset<Long> df = spark.range(1000); CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42); Assert.assertEquals(sketch1.totalCount(), 1000); @@ -371,7 +365,7 @@ public class JavaDataFrameSuite { @Test public void testBloomFilter() { - Dataset<Long> df = context.range(1000); + Dataset<Long> df = spark.range(1000); BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03); Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3); http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f1b1c22..8354a5b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,46 +23,43 @@ import java.sql.Date; import java.sql.Timestamp; import java.util.*; -import com.google.common.base.Objects; -import org.junit.rules.ExpectedException; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; import scala.Tuple5; +import com.google.common.base.Objects; import org.junit.*; +import org.junit.rules.ExpectedException; import org.apache.spark.Accumulator; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.*; import org.apache.spark.sql.*; -import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.StructType; - -import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.functions.col; +import static org.apache.spark.sql.functions.expr; import static org.apache.spark.sql.types.DataTypes.*; public class JavaDatasetSuite implements Serializable { + private 
transient TestSparkSession spark; private transient JavaSparkContext jsc; - private transient TestSQLContext context; @Before public void setUp() { // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); + spark = new TestSparkSession(); + jsc = new JavaSparkContext(spark.sparkContext()); + spark.loadTestData(); } @After public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; + spark.stop(); + spark = null; } private <T1, T2> Tuple2<T1, T2> tuple2(T1 t1, T2 t2) { @@ -72,7 +69,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testCollect() { List<String> data = Arrays.asList("hello", "world"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); List<String> collected = ds.collectAsList(); Assert.assertEquals(Arrays.asList("hello", "world"), collected); } @@ -80,7 +77,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testTake() { List<String> data = Arrays.asList("hello", "world"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); List<String> collected = ds.takeAsList(1); Assert.assertEquals(Arrays.asList("hello"), collected); } @@ -88,7 +85,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testToLocalIterator() { List<String> data = Arrays.asList("hello", "world"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); Iterator<String> iter = ds.toLocalIterator(); Assert.assertEquals("hello", iter.next()); Assert.assertEquals("world", iter.next()); @@ -98,7 +95,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testCommonOperation() { List<String> data = Arrays.asList("hello", "world"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); Assert.assertEquals("hello", ds.first()); Dataset<String> filtered = ds.filter(new FilterFunction<String>() { @@ -149,7 +146,7 @@ public class JavaDatasetSuite implements Serializable { public void testForeach() { final Accumulator<Integer> accum = jsc.accumulator(0); List<String> data = Arrays.asList("a", "b", "c"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); ds.foreach(new ForeachFunction<String>() { @Override @@ -163,7 +160,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testReduce() { List<Integer> data = Arrays.asList(1, 2, 3); - Dataset<Integer> ds = context.createDataset(data, Encoders.INT()); + Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()); int reduced = ds.reduce(new ReduceFunction<Integer>() { @Override @@ -177,7 +174,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testGroupBy() { List<String> data = Arrays.asList("a", "foo", "bar"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); KeyValueGroupedDataset<Integer, String> grouped = ds.groupByKey( new MapFunction<String, Integer>() { @Override @@ -227,7 +224,7 @@ public class JavaDatasetSuite implements 
Serializable { toSet(reduced.collectAsList())); List<Integer> data2 = Arrays.asList(2, 6, 10); - Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()); + Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT()); KeyValueGroupedDataset<Integer, Integer> grouped2 = ds2.groupByKey( new MapFunction<Integer, Integer>() { @Override @@ -261,7 +258,7 @@ public class JavaDatasetSuite implements Serializable { @Test public void testSelect() { List<Integer> data = Arrays.asList(2, 6); - Dataset<Integer> ds = context.createDataset(data, Encoders.INT()); + Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()); Dataset<Tuple2<Integer, String>> selected = ds.select( expr("value + 1"), @@ -275,12 +272,12 @@ public class JavaDatasetSuite implements Serializable { @Test public void testSetOperation() { List<String> data = Arrays.asList("abc", "abc", "xyz"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList())); List<String> data2 = Arrays.asList("xyz", "foo", "foo"); - Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING()); + Dataset<String> ds2 = spark.createDataset(data2, Encoders.STRING()); Dataset<String> intersected = ds.intersect(ds2); Assert.assertEquals(Arrays.asList("xyz"), intersected.collectAsList()); @@ -307,9 +304,9 @@ public class JavaDatasetSuite implements Serializable { @Test public void testJoin() { List<Integer> data = Arrays.asList(1, 2, 3); - Dataset<Integer> ds = context.createDataset(data, Encoders.INT()).as("a"); + Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()).as("a"); List<Integer> data2 = Arrays.asList(2, 3, 4); - Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT()).as("b"); + Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT()).as("b"); Dataset<Tuple2<Integer, Integer>> joined = ds.joinWith(ds2, col("a.value").equalTo(col("b.value"))); @@ -322,21 +319,21 @@ public class JavaDatasetSuite implements Serializable { public void testTupleEncoder() { Encoder<Tuple2<Integer, String>> encoder2 = Encoders.tuple(Encoders.INT(), Encoders.STRING()); List<Tuple2<Integer, String>> data2 = Arrays.asList(tuple2(1, "a"), tuple2(2, "b")); - Dataset<Tuple2<Integer, String>> ds2 = context.createDataset(data2, encoder2); + Dataset<Tuple2<Integer, String>> ds2 = spark.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); Encoder<Tuple3<Integer, Long, String>> encoder3 = Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING()); List<Tuple3<Integer, Long, String>> data3 = Arrays.asList(new Tuple3<>(1, 2L, "a")); - Dataset<Tuple3<Integer, Long, String>> ds3 = context.createDataset(data3, encoder3); + Dataset<Tuple3<Integer, Long, String>> ds3 = spark.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); Encoder<Tuple4<Integer, String, Long, String>> encoder4 = Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING()); List<Tuple4<Integer, String, Long, String>> data4 = Arrays.asList(new Tuple4<>(1, "b", 2L, "a")); - Dataset<Tuple4<Integer, String, Long, String>> ds4 = context.createDataset(data4, encoder4); + Dataset<Tuple4<Integer, String, Long, String>> ds4 = spark.createDataset(data4, encoder4); Assert.assertEquals(data4, ds4.collectAsList()); Encoder<Tuple5<Integer, String, Long, String, Boolean>> encoder5 = @@ -345,7 +342,7 @@ public class JavaDatasetSuite 
implements Serializable { List<Tuple5<Integer, String, Long, String, Boolean>> data5 = Arrays.asList(new Tuple5<>(1, "b", 2L, "a", true)); Dataset<Tuple5<Integer, String, Long, String, Boolean>> ds5 = - context.createDataset(data5, encoder5); + spark.createDataset(data5, encoder5); Assert.assertEquals(data5, ds5.collectAsList()); } @@ -356,7 +353,7 @@ public class JavaDatasetSuite implements Serializable { Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING()); List<Tuple2<Tuple2<Integer, String>, String>> data = Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b")); - Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder); + Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); // test (int, (string, string, long)) @@ -366,7 +363,7 @@ public class JavaDatasetSuite implements Serializable { List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 = Arrays.asList(tuple2(1, new Tuple3<>("a", "b", 3L))); Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 = - context.createDataset(data2, encoder2); + spark.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); // test (int, ((string, long), string)) @@ -376,7 +373,7 @@ public class JavaDatasetSuite implements Serializable { List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 = Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b"))); Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 = - context.createDataset(data3, encoder3); + spark.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); } @@ -390,7 +387,7 @@ public class JavaDatasetSuite implements Serializable { 1.7976931348623157E308, new BigDecimal("0.922337203685477589"), Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE)); Dataset<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> ds = - context.createDataset(data, encoder); + spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @@ -441,7 +438,7 @@ public class JavaDatasetSuite implements Serializable { Encoder<KryoSerializable> encoder = Encoders.kryo(KryoSerializable.class); List<KryoSerializable> data = Arrays.asList( new KryoSerializable("hello"), new KryoSerializable("world")); - Dataset<KryoSerializable> ds = context.createDataset(data, encoder); + Dataset<KryoSerializable> ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @@ -450,14 +447,14 @@ public class JavaDatasetSuite implements Serializable { Encoder<JavaSerializable> encoder = Encoders.javaSerialization(JavaSerializable.class); List<JavaSerializable> data = Arrays.asList( new JavaSerializable("hello"), new JavaSerializable("world")); - Dataset<JavaSerializable> ds = context.createDataset(data, encoder); + Dataset<JavaSerializable> ds = spark.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); } @Test public void testRandomSplit() { List<String> data = Arrays.asList("hello", "world", "from", "spark"); - Dataset<String> ds = context.createDataset(data, Encoders.STRING()); + Dataset<String> ds = spark.createDataset(data, Encoders.STRING()); double[] arraySplit = {1, 2, 3}; List<Dataset<String>> randomSplit = ds.randomSplitAsList(arraySplit, 1); @@ -647,14 +644,14 @@ public class JavaDatasetSuite implements Serializable { obj2.setF(Arrays.asList(300L, null, 400L)); List<SimpleJavaBean> data = 
Arrays.asList(obj1, obj2); - Dataset<SimpleJavaBean> ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class)); + Dataset<SimpleJavaBean> ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds.collectAsList()); NestedJavaBean obj3 = new NestedJavaBean(); obj3.setA(obj1); List<NestedJavaBean> data2 = Arrays.asList(obj3); - Dataset<NestedJavaBean> ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class)); + Dataset<NestedJavaBean> ds2 = spark.createDataset(data2, Encoders.bean(NestedJavaBean.class)); Assert.assertEquals(data2, ds2.collectAsList()); Row row1 = new GenericRow(new Object[]{ @@ -678,7 +675,7 @@ public class JavaDatasetSuite implements Serializable { .add("d", createArrayType(StringType)) .add("e", createArrayType(StringType)) .add("f", createArrayType(LongType)); - Dataset<SimpleJavaBean> ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema) + Dataset<SimpleJavaBean> ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema) .as(Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds3.collectAsList()); } @@ -692,7 +689,7 @@ public class JavaDatasetSuite implements Serializable { obj.setB(new Date(0)); obj.setC(java.math.BigDecimal.valueOf(1)); Dataset<SimpleJavaBean2> ds = - context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); + spark.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); ds.collect(); } @@ -776,7 +773,7 @@ public class JavaDatasetSuite implements Serializable { }) }); - Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); SmallBean smallBean = new SmallBean(); @@ -793,7 +790,7 @@ public class JavaDatasetSuite implements Serializable { { Row row = new GenericRow(new Object[] { null }); - Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); NestedSmallBean nestedSmallBean = new NestedSmallBean(); @@ -810,7 +807,7 @@ public class JavaDatasetSuite implements Serializable { }) }); - Dataset<Row> df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<Row> df = spark.createDataFrame(Collections.singletonList(row), schema); Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); ds.collect(); http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 4a78dca..2274912 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -24,33 +24,30 @@ import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import org.apache.spark.SparkContext; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.api.java.UDF2; -import org.apache.spark.api.java.JavaSparkContext; import 
org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. public class JavaUDFSuite implements Serializable { - private transient JavaSparkContext sc; - private transient SQLContext sqlContext; + private transient SparkSession spark; @Before public void setUp() { - SparkContext _sc = new SparkContext("local[*]", "testing"); - sqlContext = new SQLContext(_sc); - sc = new JavaSparkContext(_sc); + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); } @After public void tearDown() { - sqlContext.sparkContext().stop(); - sqlContext = null; - sc = null; + spark.stop(); + spark = null; } @SuppressWarnings("unchecked") @@ -60,14 +57,14 @@ public class JavaUDFSuite implements Serializable { // sqlContext.registerFunction( // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); - sqlContext.udf().register("stringLengthTest", new UDF1<String, Integer>() { + spark.udf().register("stringLengthTest", new UDF1<String, Integer>() { @Override public Integer call(String str) { return str.length(); } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); + Row result = spark.sql("SELECT stringLengthTest('test')").head(); Assert.assertEquals(4, result.getInt(0)); } @@ -80,14 +77,14 @@ public class JavaUDFSuite implements Serializable { // sqlContext.registerFunction( // "stringLengthTest", // (String str1, String str2) -> str1.length() + str2.length, // DataType.IntegerType); - sqlContext.udf().register("stringLengthTest", new UDF2<String, String, Integer>() { + spark.udf().register("stringLengthTest", new UDF2<String, String, Integer>() { @Override public Integer call(String str1, String str2) { return str1.length() + str2.length(); } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); + Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); } }
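
The Java suites touched above all converge on the same lifecycle once SQLContext and TestSQLContext are gone: a single transient SparkSession field, created in setUp and stopped in tearDown, with createDataset/createDataFrame, sql, read, range, and udf calls issued directly on the session. The sketch below illustrates that pattern for a standalone JUnit 4 test, assuming a Spark 2.x build on the classpath; the class name, test method, and assertions are illustrative and not part of this patch.

import java.io.Serializable;
import java.util.Arrays;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

// Illustrative only: shows the SparkSession-based setup/teardown the migrated suites share.
public class SparkSessionMigrationSketch implements Serializable {
  private transient SparkSession spark;

  @Before
  public void setUp() {
    // One builder call replaces the old pair of "new SparkContext(...)" and "new SQLContext(sc)".
    spark = SparkSession.builder()
      .master("local[*]")
      .appName("testing")
      .getOrCreate();
  }

  @After
  public void tearDown() {
    // stop() also shuts down the underlying SparkContext, so no separate cleanup is needed.
    spark.stop();
    spark = null;
  }

  @Test
  public void sessionReplacesSqlContext() {
    // Entry points that previously hung off SQLContext are called on the session itself.
    Dataset<String> ds = spark.createDataset(Arrays.asList("hello", "world"), Encoders.STRING());
    Assert.assertEquals(2L, ds.count());

    // length() returns an IntegerType column, so getInt(0) reads the value back.
    Row row = spark.sql("SELECT length('test')").head();
    Assert.assertEquals(4, row.getInt(0));
  }
}

Suites that rely on the shared test tables (testData, testData2, courseSales) instead instantiate TestSparkSession and call loadTestData(), as the JavaDataFrameSuite and JavaDatasetSuite hunks show; the builder form above mirrors what JavaUDFSuite uses.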