Repository: spark Updated Branches: refs/heads/master c1b62e420 -> 3bcb1b481
Revert "[SPARK-24776][SQL] Avro unit test: use SQLTestUtils and replace deprecated methods" This reverts commit c1b62e420a43aa7da36733ccdbec057d87ac1b43. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3bcb1b48 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3bcb1b48 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3bcb1b48 Branch: refs/heads/master Commit: 3bcb1b481423aedf1ac531ad582c7cb8685f1e3c Parents: c1b62e4 Author: Xiao Li <gatorsm...@gmail.com> Authored: Fri Jul 13 10:06:26 2018 -0700 Committer: Xiao Li <gatorsm...@gmail.com> Committed: Fri Jul 13 10:06:26 2018 -0700 ---------------------------------------------------------------------- .../org/apache/spark/sql/avro/AvroSuite.scala | 114 +++++++------- .../org/apache/spark/sql/avro/TestUtils.scala | 156 +++++++++++++++++++ 2 files changed, 217 insertions(+), 53 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/3bcb1b48/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala ---------------------------------------------------------------------- diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 108b347..c6c1e40 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -31,24 +31,32 @@ import org.apache.avro.generic.{GenericData, GenericDatumWriter, GenericRecord} import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed} import org.apache.commons.io.FileUtils +import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.avro.SchemaConverters.IncompatibleSchemaException -import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.sql.types._ -class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { +class AvroSuite extends SparkFunSuite { val episodesFile = "src/test/resources/episodes.avro" val testFile = "src/test/resources/test.avro" + private var spark: SparkSession = _ + override protected def beforeAll(): Unit = { super.beforeAll() - spark.conf.set("spark.sql.files.maxPartitionBytes", 1024) - } - - def checkReloadMatchesSaved(originalFile: String, newFile: String): Unit = { - val originalEntries = spark.read.avro(testFile).collect() - val newEntries = spark.read.avro(newFile) - checkAnswer(newEntries, originalEntries) + spark = SparkSession.builder() + .master("local[2]") + .appName("AvroSuite") + .config("spark.sql.files.maxPartitionBytes", 1024) + .getOrCreate() + } + + override protected def afterAll(): Unit = { + try { + spark.sparkContext.stop() + } finally { + super.afterAll() + } } test("reading from multiple paths") { @@ -60,7 +68,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { val df = spark.read.avro(episodesFile) val fields = List("title", "air_date", "doctor") for (field <- fields) { - withTempPath { dir => + TestUtils.withTempDir { dir => val outputDir = s"$dir/${UUID.randomUUID}" df.write.partitionBy(field).avro(outputDir) val input = spark.read.avro(outputDir) @@ -74,12 +82,12 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("request no fields") { val df = spark.read.avro(episodesFile) - df.createOrReplaceTempView("avro_table") + df.registerTempTable("avro_table") assert(spark.sql("select count(*) from avro_table").collect().head === Row(8)) } test("convert formats") { - withTempPath { dir => + TestUtils.withTempDir { dir => val df = spark.read.avro(episodesFile) df.write.parquet(dir.getCanonicalPath) assert(spark.read.parquet(dir.getCanonicalPath).count() === df.count) @@ -87,16 +95,15 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("rearrange internal schema") { - withTempPath { dir => + TestUtils.withTempDir { dir => val df = spark.read.avro(episodesFile) df.select("doctor", "title").write.avro(dir.getCanonicalPath) } } test("test NULL avro type") { - withTempPath { dir => - val fields = - Seq(new Field("null", Schema.create(Type.NULL), "doc", null.asInstanceOf[Any])).asJava + TestUtils.withTempDir { dir => + val fields = Seq(new Field("null", Schema.create(Type.NULL), "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) val datumWriter = new GenericDatumWriter[GenericRecord](schema) @@ -115,11 +122,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("union(int, long) is read as long") { - withTempPath { dir => + TestUtils.withTempDir { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.INT), Schema.create(Type.LONG)).asJava) - val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava + val fields = Seq(new Field("field1", union, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -143,11 +150,11 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("union(float, double) is read as double") { - withTempPath { dir => + TestUtils.withTempDir { dir => val avroSchema: Schema = { val union = Schema.createUnion(List(Schema.create(Type.FLOAT), Schema.create(Type.DOUBLE)).asJava) - val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava + val fields = Seq(new Field("field1", union, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -171,7 +178,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("union(float, double, null) is read as nullable double") { - withTempPath { dir => + TestUtils.withTempDir { dir => val avroSchema: Schema = { val union = Schema.createUnion( List(Schema.create(Type.FLOAT), @@ -179,7 +186,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { Schema.create(Type.NULL) ).asJava ) - val fields = Seq(new Field("field1", union, "doc", null.asInstanceOf[Any])).asJava + val fields = Seq(new Field("field1", union, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) schema @@ -203,9 +210,9 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Union of a single type") { - withTempPath { dir => + TestUtils.withTempDir { dir => val UnionOfOne = Schema.createUnion(List(Schema.create(Type.INT)).asJava) - val fields = Seq(new Field("field1", UnionOfOne, "doc", null.asInstanceOf[Any])).asJava + val fields = Seq(new Field("field1", UnionOfOne, "doc", null)).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -226,16 +233,16 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Complex Union Type") { - withTempPath { dir => + TestUtils.withTempDir { dir => val fixedSchema = Schema.createFixed("fixed_name", "doc", "namespace", 4) val enumSchema = Schema.createEnum("enum_name", "doc", "namespace", List("e1", "e2").asJava) val complexUnionType = Schema.createUnion( List(Schema.create(Type.INT), Schema.create(Type.STRING), fixedSchema, enumSchema).asJava) val fields = Seq( - new Field("field1", complexUnionType, "doc", null.asInstanceOf[Any]), - new Field("field2", complexUnionType, "doc", null.asInstanceOf[Any]), - new Field("field3", complexUnionType, "doc", null.asInstanceOf[Any]), - new Field("field4", complexUnionType, "doc", null.asInstanceOf[Any]) + new Field("field1", complexUnionType, "doc", null), + new Field("field2", complexUnionType, "doc", null), + new Field("field3", complexUnionType, "doc", null), + new Field("field4", complexUnionType, "doc", null) ).asJava val schema = Schema.createRecord("name", "docs", "namespace", false) schema.setFields(fields) @@ -264,7 +271,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Lots of nulls") { - withTempPath { dir => + TestUtils.withTempDir { dir => val schema = StructType(Seq( StructField("binary", BinaryType, true), StructField("timestamp", TimestampType, true), @@ -283,7 +290,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Struct field type") { - withTempPath { dir => + TestUtils.withTempDir { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("short", ShortType, true), @@ -302,7 +309,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Date field type") { - withTempPath { dir => + TestUtils.withTempDir { dir => val schema = StructType(Seq( StructField("float", FloatType, true), StructField("date", DateType, true) @@ -322,7 +329,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("Array data types") { - withTempPath { dir => + TestUtils.withTempDir { dir => val testSchema = StructType(Seq( StructField("byte_array", ArrayType(ByteType), true), StructField("short_array", ArrayType(ShortType), true), @@ -356,12 +363,13 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("write with compression") { - withTempPath { dir => + TestUtils.withTempDir { dir => val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec" val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level" val uncompressDir = s"$dir/uncompress" val deflateDir = s"$dir/deflate" val snappyDir = s"$dir/snappy" + val fakeDir = s"$dir/fake" val df = spark.read.avro(testFile) spark.conf.set(AVRO_COMPRESSION_CODEC, "uncompressed") @@ -431,7 +439,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("sql test") { spark.sql( s""" - |CREATE TEMPORARY VIEW avroTable + |CREATE TEMPORARY TABLE avroTable |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) @@ -442,24 +450,24 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("conversion to avro and back") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - withTempPath { dir => + TestUtils.withTempDir { dir => val avroDir = s"$dir/avro" spark.read.avro(testFile).write.avro(avroDir) - checkReloadMatchesSaved(testFile, avroDir) + TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) } } test("conversion to avro and back with namespace") { // Note that test.avro includes a variety of types, some of which are nullable. We expect to // get the same values back. - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val name = "AvroTest" val namespace = "com.databricks.spark.avro" val parameters = Map("recordName" -> name, "recordNamespace" -> namespace) val avroDir = tempDir + "/namedAvro" spark.read.avro(testFile).write.options(parameters).avro(avroDir) - checkReloadMatchesSaved(testFile, avroDir) + TestUtils.checkReloadMatchesSaved(spark, testFile, avroDir) // Look at raw file and make sure has namespace info val rawSaved = spark.sparkContext.textFile(avroDir) @@ -470,7 +478,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("converting some specific sparkSQL types to avro") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val testSchema = StructType(Seq( StructField("Name", StringType, false), StructField("Length", IntegerType, true), @@ -512,7 +520,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("correctly read long as date/timestamp type") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -541,7 +549,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("does not coerce null date/timestamp value to 0 epoch.") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val sparkSession = spark import sparkSession.implicits._ @@ -602,7 +610,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { // Directory given has no avro files intercept[AnalysisException] { - withTempPath(dir => spark.read.avro(dir.getCanonicalPath)) + TestUtils.withTempDir(dir => spark.read.avro(dir.getCanonicalPath)) } intercept[AnalysisException] { @@ -616,7 +624,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } intercept[FileNotFoundException] { - withTempPath { dir => + TestUtils.withTempDir { dir => FileUtils.touch(new File(dir, "test")) spark.read.avro(dir.toString) } @@ -625,19 +633,19 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("SQL test insert overwrite") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val tempEmptyDir = s"$tempDir/sqlOverwrite" // Create a temp directory for table that will be overwritten new File(tempEmptyDir).mkdirs() spark.sql( s""" - |CREATE TEMPORARY VIEW episodes + |CREATE TEMPORARY TABLE episodes |USING avro |OPTIONS (path "$episodesFile") """.stripMargin.replaceAll("\n", " ")) spark.sql( s""" - |CREATE TEMPORARY VIEW episodesEmpty + |CREATE TEMPORARY TABLE episodesEmpty |(name string, air_date string, doctor int) |USING avro |OPTIONS (path "$tempEmptyDir") @@ -657,7 +665,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test save and load") { // Test if load works as expected - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -671,7 +679,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { test("test load with non-Avro file") { // Test if load works as expected - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => val df = spark.read.avro(episodesFile) assert(df.count == 8) @@ -729,7 +737,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } test("read avro file partitioned") { - withTempPath { dir => + TestUtils.withTempDir { dir => val sparkSession = spark import sparkSession.implicits._ val df = (0 to 1024 * 3).toDS.map(i => s"record${i}").toDF("records") @@ -748,7 +756,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { case class NestedTop(id: Int, data: NestedMiddle) test("saving avro that has nested records with the same name") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame(List(NestedTop(1, NestedMiddle(2, NestedBottom(3, "1"))))) val outputFolder = s"$tempDir/duplicate_names/" @@ -765,7 +773,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { case class NestedTopArray(id: Int, data: NestedMiddleArray) test("saving avro that has nested records with the same name inside an array") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopArray(1, NestedMiddleArray(2, Array( @@ -786,7 +794,7 @@ class AvroSuite extends QueryTest with SharedSQLContext with SQLTestUtils { case class NestedTopMap(id: Int, data: NestedMiddleMap) test("saving avro that has nested records with the same name inside a map") { - withTempPath { tempDir => + TestUtils.withTempDir { tempDir => // Save avro file on output folder path val writeDf = spark.createDataFrame( List(NestedTopMap(1, NestedMiddleMap(2, Map( http://git-wip-us.apache.org/repos/asf/spark/blob/3bcb1b48/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala ---------------------------------------------------------------------- diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala new file mode 100755 index 0000000..4ae9b14 --- /dev/null +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/TestUtils.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.io.{File, IOException} +import java.nio.ByteBuffer + +import scala.collection.immutable.HashSet +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import com.google.common.io.Files +import java.util + +import org.apache.spark.sql.SparkSession + +private[avro] object TestUtils { + + /** + * This function checks that all records in a file match the original + * record. + */ + def checkReloadMatchesSaved(spark: SparkSession, testFile: String, avroDir: String): Unit = { + + def convertToString(elem: Any): String = { + elem match { + case null => "NULL" // HashSets can't have null in them, so we use a string instead + case arrayBuf: ArrayBuffer[_] => + arrayBuf.asInstanceOf[ArrayBuffer[Any]].toArray.deep.mkString(" ") + case arrayByte: Array[Byte] => arrayByte.deep.mkString(" ") + case other => other.toString + } + } + + val originalEntries = spark.read.avro(testFile).collect() + val newEntries = spark.read.avro(avroDir).collect() + + assert(originalEntries.length == newEntries.length) + + val origEntrySet = Array.fill(originalEntries(0).size)(new HashSet[Any]()) + for (origEntry <- originalEntries) { + var idx = 0 + for (origElement <- origEntry.toSeq) { + origEntrySet(idx) += convertToString(origElement) + idx += 1 + } + } + + for (newEntry <- newEntries) { + var idx = 0 + for (newElement <- newEntry.toSeq) { + assert(origEntrySet(idx).contains(convertToString(newElement))) + idx += 1 + } + } + } + + def withTempDir(f: File => Unit): Unit = { + val dir = Files.createTempDir() + dir.delete() + try f(dir) finally deleteRecursively(dir) + } + + /** + * This function deletes a file or a directory with everything that's in it. This function is + * copied from Spark with minor modifications made to it. See original source at: + * github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/util/Utils.scala + */ + + def deleteRecursively(file: File) { + def listFilesSafely(file: File): Seq[File] = { + if (file.exists()) { + val files = file.listFiles() + if (files == null) { + throw new IOException("Failed to list files for dir: " + file) + } + files + } else { + List() + } + } + + if (file != null) { + try { + if (file.isDirectory) { + var savedIOException: IOException = null + for (child <- listFilesSafely(file)) { + try { + deleteRecursively(child) + } catch { + // In case of multiple exceptions, only last one will be thrown + case ioe: IOException => savedIOException = ioe + } + } + if (savedIOException != null) { + throw savedIOException + } + } + } finally { + if (!file.delete()) { + // Delete can also fail if the file simply did not exist + if (file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath) + } + } + } + } + } + + /** + * This function generates a random map(string, int) of a given size. + */ + private[avro] def generateRandomMap(rand: Random, size: Int): java.util.Map[String, Int] = { + val jMap = new util.HashMap[String, Int]() + for (i <- 0 until size) { + jMap.put(rand.nextString(5), i) + } + jMap + } + + /** + * This function generates a random array of booleans of a given size. + */ + private[avro] def generateRandomArray(rand: Random, size: Int): util.ArrayList[Boolean] = { + val vec = new util.ArrayList[Boolean]() + for (i <- 0 until size) { + vec.add(rand.nextBoolean()) + } + vec + } + + /** + * This function generates a random ByteBuffer of a given size. + */ + private[avro] def generateRandomByteBuffer(rand: Random, size: Int): ByteBuffer = { + val bb = ByteBuffer.allocate(size) + val arrayOfBytes = new Array[Byte](size) + rand.nextBytes(arrayOfBytes) + bb.put(arrayOfBytes) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org