Repository: spark
Updated Branches:
  refs/heads/master 56f501e1c -> 3121b411f


[SPARK-23846][SQL] The samplingRatio option for CSV datasource

## What changes were proposed in this pull request?

I propose to support the `samplingRatio` option for schema inference in the CSV
datasource, similar to the same option in the JSON datasource:
https://github.com/apache/spark/blob/b14993e1fcb68e1c946a671c6048605ab4afdf58/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala#L49-L50
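
As a usage illustration (not part of the patch; a minimal sketch assuming an
existing `SparkSession` named `spark` and a hypothetical input path), setting
`samplingRatio` below 1.0 makes the schema-inference pass read only a sampled
fraction of the rows, while the subsequent parsing pass still consumes the
full input:

```scala
// Sketch only: "/path/to/data.csv" is a placeholder, not a file from the repo.
val df = spark.read
  .option("inferSchema", true)   // enable the extra schema-inference pass
  .option("samplingRatio", 0.1)  // sample ~10% of rows for that pass only
  .csv("/path/to/data.csv")
```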

## How was this patch tested?

Added 2 tests for the JSON datasource and 2 tests for the CSV datasource. The
tests check that only a subset of the input dataset is used for schema
inference.
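
The CSV tests are driven by a sampling-sensitive dataset (see
`TestCsvData.scala` in the diff below). Condensed to its core idea (a sketch,
not the exact test code):

```scala
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}

// Rows at the predefined indices are integer-formatted; all others are
// doubles. With samplingRatio = 0.1 and the fixed seed used by the reader,
// the sample happens to hit only integer rows, so inference yields
// IntegerType, whereas a full scan would merge in the doubles and yield
// DoubleType.
def sampledTestData(spark: SparkSession): Dataset[String] = {
  val intRows = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46, 57, 62, 68, 72)
  spark.range(0, 100, 1).map { i =>
    if (intRows.contains(i)) i.toString else (i.toDouble + 0.1).toString
  }(Encoders.STRING)
}
```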

Author: Maxim Gekk <[email protected]>

Closes #20959 from MaxGekk/csv-sampling.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3121b411
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3121b411
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3121b411

Branch: refs/heads/master
Commit: 3121b411f748859ed3ed1c97cbc21e6ae980a35c
Parents: 56f501e
Author: Maxim Gekk <[email protected]>
Authored: Mon Apr 30 09:45:22 2018 +0800
Committer: hyukjinkwon <[email protected]>
Committed: Mon Apr 30 09:45:22 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/readwriter.py                |  7 ++-
 python/pyspark/sql/tests.py                     |  7 +++
 .../org/apache/spark/sql/DataFrameReader.scala  |  1 +
 .../datasources/csv/CSVDataSource.scala         |  6 ++-
 .../execution/datasources/csv/CSVOptions.scala  |  3 ++
 .../execution/datasources/csv/CSVUtils.scala    | 28 ++++++++++++
 .../execution/datasources/csv/CSVSuite.scala    | 47 +++++++++++++++++++-
 .../execution/datasources/csv/TestCsvData.scala | 36 +++++++++++++++
 8 files changed, 129 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3121b411/python/pyspark/sql/readwriter.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 6811fa6..9899eb5 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -345,7 +345,8 @@ class DataFrameReader(OptionUtils):
             ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
             negativeInf=None, dateFormat=None, timestampFormat=None, maxColumns=None,
             maxCharsPerColumn=None, maxMalformedLogPerPartition=None, mode=None,
-            columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None):
+            columnNameOfCorruptRecord=None, multiLine=None, charToEscapeQuoteEscaping=None,
+            samplingRatio=None):
         """Loads a CSV file and returns the result as a  :class:`DataFrame`.
 
         This function will go through the input once to determine the input schema if
@@ -428,6 +429,8 @@ class DataFrameReader(OptionUtils):
                                           the quote character. If None is set, the default value is
                                           escape character when escape and quote characters are
                                           different, ``\0`` otherwise.
+        :param samplingRatio: defines fraction of rows used for schema inferring.
+                              If None is set, it uses the default value, ``1.0``.
 
         >>> df = spark.read.csv('python/test_support/sql/ages.csv')
         >>> df.dtypes
@@ -446,7 +449,7 @@ class DataFrameReader(OptionUtils):
             maxCharsPerColumn=maxCharsPerColumn,
             maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode,
             columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine,
-            charToEscapeQuoteEscaping=charToEscapeQuoteEscaping)
+            charToEscapeQuoteEscaping=charToEscapeQuoteEscaping, samplingRatio=samplingRatio)
         if isinstance(path, basestring):
             path = [path]
         if type(path) == list:

http://git-wip-us.apache.org/repos/asf/spark/blob/3121b411/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index e0cd2aa..bc3eaf1 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3033,6 +3033,13 @@ class SQLTests(ReusedSQLTestCase):
             .json(rdd).schema
         self.assertEquals(schema, StructType([StructField("a", LongType(), 
True)]))
 
+    def test_csv_sampling_ratio(self):
+        rdd = self.spark.sparkContext.range(0, 100, 1, 1) \
+            .map(lambda x: '0.1' if x == 1 else str(x))
+        schema = self.spark.read.option('inferSchema', True)\
+            .csv(rdd, samplingRatio=0.5).schema
+        self.assertEquals(schema, StructType([StructField("_c0", IntegerType(), True)]))
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3121b411/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 6b2ea6c..53f4488 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -539,6 +539,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * <li>`header` (default `false`): uses the first line as names of columns.</li>
    * <li>`inferSchema` (default `false`): infers the input schema automatically from data. It
    * requires one extra pass over the data.</li>
+   * <li>`samplingRatio` (default is 1.0): defines fraction of rows used for schema inferring.</li>
    * <li>`ignoreLeadingWhiteSpace` (default `false`): a flag indicating whether or not leading
    * whitespaces from values being read should be skipped.</li>
    * <li>`ignoreTrailingWhiteSpace` (default `false`): a flag indicating whether or not trailing

http://git-wip-us.apache.org/repos/asf/spark/blob/3121b411/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index 4870d75..bc1f4ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -161,7 +161,8 @@ object TextInputCSVDataSource extends CSVDataSource {
       val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine)
       val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
       val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
-      val tokenRDD = csv.rdd.mapPartitions { iter =>
+      val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions)
+      val tokenRDD = sampled.rdd.mapPartitions { iter =>
         val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions)
         val linesWithoutHeader =
           CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions)
@@ -235,7 +236,8 @@ object MultiLineCSVDataSource extends CSVDataSource {
             parsedOptions.headerFlag,
             new CsvParser(parsedOptions.asParserSettings))
         }
-        CSVInferSchema.infer(tokenRDD, header, parsedOptions)
+        val sampled = CSVUtils.sample(tokenRDD, parsedOptions)
+        CSVInferSchema.infer(sampled, header, parsedOptions)
       case None =>
         // If the first row could not be read, just return the empty schema.
         StructType(Nil)

http://git-wip-us.apache.org/repos/asf/spark/blob/3121b411/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index c167906..2ec0fc6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -150,6 +150,9 @@ class CSVOptions(
 
   val isCommentSet = this.comment != '\u0000'
 
+  val samplingRatio =
+    parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
+
   def asWriterSettings: CsvWriterSettings = {
     val writerSettings = new CsvWriterSettings()
     val format = writerSettings.getFormat

http://git-wip-us.apache.org/repos/asf/spark/blob/3121b411/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
index 72b053d..31464f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql.execution.datasources.csv
 
+import org.apache.spark.input.PortableDataStream
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.catalyst.json.JSONOptions
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 
@@ -131,4 +134,29 @@ object CSVUtils {
     schema.foreach(field => verifyType(field.dataType))
   }
 
+  /**
+   * Sample CSV dataset as configured by `samplingRatio`.
+   */
+  def sample(csv: Dataset[String], options: CSVOptions): Dataset[String] = {
+    require(options.samplingRatio > 0,
+      s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+    if (options.samplingRatio > 0.99) {
+      csv
+    } else {
+      csv.sample(withReplacement = false, options.samplingRatio, 1)
+    }
+  }
+
+  /**
+   * Sample CSV RDD as configured by `samplingRatio`.
+   */
+  def sample(csv: RDD[Array[String]], options: CSVOptions): RDD[Array[String]] = {
+    require(options.samplingRatio > 0,
+      s"samplingRatio (${options.samplingRatio}) should be greater than 0")
+    if (options.samplingRatio > 0.99) {
+      csv
+    } else {
+      csv.sample(withReplacement = false, options.samplingRatio, 1)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3121b411/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 4398e54..461abdd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -30,12 +30,11 @@ import org.apache.hadoop.io.compress.GzipCodec
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, UDT}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.functions.{col, regexp_replace}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
 import org.apache.spark.sql.types._
 
-class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
+class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with TestCsvData {
   import testImplicits._
 
   private val carsFile = "test-data/cars.csv"
@@ -1279,4 +1278,48 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
       Row("0,2013-111-11 12:13:14") :: Row(null) :: Nil
     )
   }
+
+  test("SPARK-23846: schema inferring touches less data if samplingRatio < 
1.0") {
+    // Set default values for the DataSource parameters to make sure
+    // that whole test file is mapped to only one partition. This will guarantee
+    // reliable sampling of the input file.
+    withSQLConf(
+      "spark.sql.files.maxPartitionBytes" -> (128 * 1024 * 1024).toString,
+      "spark.sql.files.openCostInBytes" -> (4 * 1024 * 1024).toString
+    )(withTempPath { path =>
+      val ds = sampledTestData.coalesce(1)
+      ds.write.text(path.getAbsolutePath)
+
+      val readback = spark.read
+        .option("inferSchema", true).option("samplingRatio", 0.1)
+        .csv(path.getCanonicalPath)
+      assert(readback.schema == new StructType().add("_c0", IntegerType))
+    })
+  }
+
+  test("SPARK-23846: usage of samplingRatio while parsing a dataset of 
strings") {
+    val ds = sampledTestData.coalesce(1)
+    val readback = spark.read
+      .option("inferSchema", true).option("samplingRatio", 0.1)
+      .csv(ds)
+
+    assert(readback.schema == new StructType().add("_c0", IntegerType))
+  }
+
+  test("SPARK-23846: samplingRatio is out of the range (0, 1.0]") {
+    val ds = spark.range(0, 100, 1, 1).map(_.toString)
+
+    val errorMsg0 = intercept[IllegalArgumentException] {
+      spark.read.option("inferSchema", true).option("samplingRatio", 
-1).csv(ds)
+    }.getMessage
+    assert(errorMsg0.contains("samplingRatio (-1.0) should be greater than 0"))
+
+    val errorMsg1 = intercept[IllegalArgumentException] {
+      spark.read.option("inferSchema", true).option("samplingRatio", 0).csv(ds)
+    }.getMessage
+    assert(errorMsg1.contains("samplingRatio (0.0) should be greater than 0"))
+
+    val sampled = spark.read.option("inferSchema", 
true).option("samplingRatio", 1.0).csv(ds)
+    assert(sampled.count() == ds.count())
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3121b411/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala
new file mode 100644
index 0000000..3e20cc4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/TestCsvData.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
+
+private[csv] trait TestCsvData {
+  protected def spark: SparkSession
+
+  def sampledTestData: Dataset[String] = {
+    spark.range(0, 100, 1).map { index =>
+      val predefinedSample = Set[Long](2, 8, 15, 27, 30, 34, 35, 37, 44, 46,
+        57, 62, 68, 72)
+      if (predefinedSample.contains(index)) {
+        index.toString
+      } else {
+        (index.toDouble + 0.1).toString
+      }
+    }(Encoders.STRING)
+  }
+}

