Repository: spark
Updated Branches:
  refs/heads/master 6df8e3886 -> b99129cc4


[SPARK-15982][SPARK-16009][SPARK-16007][SQL] Harmonize the behavior of 
DataFrameReader.text/csv/json/parquet/orc

## What changes were proposed in this pull request?

Issues with current reader behavior.
- `text()` without args returns an empty DF with no columns -> inconsistent; 
it's expected that `text()` will always return a DF with a `value` string field.
- `textFile()` without args fails with an exception for the above reason: 
it expects the DF returned by `text()` to have a `value` field.
- `orc()` does not have var args, inconsistent with others
- `json(single-arg)` was removed, but that caused source compatibility issues - 
[SPARK-16009](https://issues.apache.org/jira/browse/SPARK-16009)
- user specified schema was not respected when `text/csv/...` were used with no 
args - [SPARK-16007](https://issues.apache.org/jira/browse/SPARK-16007)

The solution I am implementing is to do the following.
- For each format, there will be a single-argument method and a vararg method. 
For json, parquet, csv, text, this means adding json(string), etc. For orc, 
this means adding orc(varargs).
- Remove the special handling of text(), csv(), etc. that returns an empty 
dataframe with no fields. Instead, pass the empty sequence of paths on to the 
datasource, and let each datasource handle it correctly. For example, the text 
data source should return an empty DF with schema (value: string).
- Deduped docs and fixed their formatting.

## How was this patch tested?
Added new unit tests to the Scala and Java test suites.

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #13727 from tdas/SPARK-15982.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b99129cc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b99129cc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b99129cc

Branch: refs/heads/master
Commit: b99129cc452defc266f6d357f5baab5f4ff37a36
Parents: 6df8e38
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Mon Jun 20 14:52:28 2016 -0700
Committer: Shixiong Zhu <shixi...@databricks.com>
Committed: Mon Jun 20 14:52:28 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/DataFrameReader.scala  | 132 +++++++++----
 .../sql/JavaDataFrameReaderWriterSuite.java     | 158 ++++++++++++++++
 .../sql/test/DataFrameReaderWriterSuite.scala   | 186 ++++++++++++++++---
 3 files changed, 420 insertions(+), 56 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b99129cc/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 2ae854d..841503b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -119,13 +119,7 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
    * @since 1.4.0
    */
   def load(): DataFrame = {
-    val dataSource =
-      DataSource(
-        sparkSession,
-        userSpecifiedSchema = userSpecifiedSchema,
-        className = source,
-        options = extraOptions.toMap)
-    Dataset.ofRows(sparkSession, LogicalRelation(dataSource.resolveRelation()))
+    load(Seq.empty: _*) // force invocation of `load(...varargs...)`
   }
 
   /**
@@ -135,7 +129,7 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
    * @since 1.4.0
    */
   def load(path: String): DataFrame = {
-    option("path", path).load()
+    load(Seq(path): _*) // force invocation of `load(...varargs...)`
   }
 
   /**
@@ -146,18 +140,15 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
    */
   @scala.annotation.varargs
   def load(paths: String*): DataFrame = {
-    if (paths.isEmpty) {
-      sparkSession.emptyDataFrame
-    } else {
-      sparkSession.baseRelationToDataFrame(
-        DataSource.apply(
-          sparkSession,
-          paths = paths,
-          userSpecifiedSchema = userSpecifiedSchema,
-          className = source,
-          options = extraOptions.toMap).resolveRelation())
-    }
+    sparkSession.baseRelationToDataFrame(
+      DataSource.apply(
+        sparkSession,
+        paths = paths,
+        userSpecifiedSchema = userSpecifiedSchema,
+        className = source,
+        options = extraOptions.toMap).resolveRelation())
   }
+
   /**
    * Construct a [[DataFrame]] representing the database table accessible via 
JDBC URL
    * url named table and connection properties.
@@ -247,11 +238,23 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
 
   /**
    * Loads a JSON file (one object per line) and returns the result as a 
[[DataFrame]].
+   * See the documentation on the overloaded `json()` method with varargs for 
more details.
+   *
+   * @since 1.4.0
+   */
+  def json(path: String): DataFrame = {
+    // This method ensures that calls that explicit need single argument 
works, see SPARK-16009
+    json(Seq(path): _*)
+  }
+
+  /**
+   * Loads a JSON file (one object per line) and returns the result as a 
[[DataFrame]].
    *
    * This function goes through the input once to determine the input schema. 
If you know the
    * schema in advance, use the version that specifies the schema to avoid the 
extra scan.
    *
    * You can set the following JSON-specific options to deal with non-standard 
JSON files:
+   * <ul>
    * <li>`primitivesAsString` (default `false`): infers all primitive values 
as a string type</li>
    * <li>`prefersDecimal` (default `false`): infers all floating-point values 
as a decimal
    * type. If the values do not fit in decimal, then it infers them as 
doubles.</li>
@@ -266,17 +269,17 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
    * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt 
records
    * during parsing.</li>
    * <ul>
-   *  <li>`PERMISSIVE` : sets other fields to `null` when it meets a corrupted 
record, and puts the
-   *  malformed string into a new field configured by 
`columnNameOfCorruptRecord`. When
+   *  <li> - `PERMISSIVE` : sets other fields to `null` when it meets a 
corrupted record, and puts
+   *  the malformed string into a new field configured by 
`columnNameOfCorruptRecord`. When
    *  a schema is set by user, it sets `null` for extra fields.</li>
-   *  <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
-   *  <li>`FAILFAST` : throws an exception when it meets corrupted 
records.</li>
+   *  <li> - `DROPMALFORMED` : ignores the whole corrupted records.</li>
+   *  <li> - `FAILFAST` : throws an exception when it meets corrupted 
records.</li>
    * </ul>
    * <li>`columnNameOfCorruptRecord` (default is the value specified in
    * `spark.sql.columnNameOfCorruptRecord`): allows renaming the new field 
having malformed string
    * created by `PERMISSIVE` mode. This overrides 
`spark.sql.columnNameOfCorruptRecord`.</li>
-   *
-   * @since 1.6.0
+   * </ul>
+   * @since 2.0.0
    */
   @scala.annotation.varargs
   def json(paths: String*): DataFrame = format("json").load(paths : _*)
@@ -327,6 +330,17 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
   }
 
   /**
+   * Loads a CSV file and returns the result as a [[DataFrame]]. See the 
documentation on the
+   * other overloaded `csv()` method for more details.
+   *
+   * @since 2.0.0
+   */
+  def csv(path: String): DataFrame = {
+    // This method ensures that calls that explicit need single argument 
works, see SPARK-16009
+    csv(Seq(path): _*)
+  }
+
+  /**
    * Loads a CSV file and returns the result as a [[DataFrame]].
    *
    * This function will go through the input once to determine the input 
schema if `inferSchema`
@@ -334,6 +348,7 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
    * specify the schema explicitly using [[schema]].
    *
    * You can set the following CSV-specific options to deal with CSV files:
+   * <ul>
    * <li>`sep` (default `,`): sets the single character as a separator for each
    * field and value.</li>
    * <li>`encoding` (default `UTF-8`): decodes the CSV files by the given 
encoding
@@ -370,26 +385,37 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
    * <li>`mode` (default `PERMISSIVE`): allows a mode for dealing with corrupt 
records
    *    during parsing.</li>
    * <ul>
-   *   <li>`PERMISSIVE` : sets other fields to `null` when it meets a 
corrupted record. When
+   *   <li> - `PERMISSIVE` : sets other fields to `null` when it meets a 
corrupted record. When
    *     a schema is set by user, it sets `null` for extra fields.</li>
-   *   <li>`DROPMALFORMED` : ignores the whole corrupted records.</li>
-   *   <li>`FAILFAST` : throws an exception when it meets corrupted 
records.</li>
+   *   <li> - `DROPMALFORMED` : ignores the whole corrupted records.</li>
+   *   <li> - `FAILFAST` : throws an exception when it meets corrupted 
records.</li>
+   * </ul>
    * </ul>
-   *
    * @since 2.0.0
    */
   @scala.annotation.varargs
   def csv(paths: String*): DataFrame = format("csv").load(paths : _*)
 
   /**
-   * Loads a Parquet file, returning the result as a [[DataFrame]]. This 
function returns an empty
-   * [[DataFrame]] if no paths are passed in.
+   * Loads a Parquet file, returning the result as a [[DataFrame]]. See the 
documentation
+   * on the other overloaded `parquet()` method for more details.
+   *
+   * @since 2.0.0
+   */
+  def parquet(path: String): DataFrame = {
+    // This method ensures that calls that explicit need single argument 
works, see SPARK-16009
+    parquet(Seq(path): _*)
+  }
+
+  /**
+   * Loads a Parquet file, returning the result as a [[DataFrame]].
    *
    * You can set the following Parquet-specific option(s) for reading Parquet 
files:
+   * <ul>
    * <li>`mergeSchema` (default is the value specified in 
`spark.sql.parquet.mergeSchema`): sets
    * whether we should merge schemas collected from all Parquet part-files. 
This will override
    * `spark.sql.parquet.mergeSchema`.</li>
-   *
+   * </ul>
    * @since 1.4.0
    */
   @scala.annotation.varargs
@@ -404,7 +430,20 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
    * @since 1.5.0
    * @note Currently, this method can only be used after enabling Hive support.
    */
-  def orc(path: String): DataFrame = format("orc").load(path)
+  def orc(path: String): DataFrame = {
+    // This method ensures that calls that explicit need single argument 
works, see SPARK-16009
+    orc(Seq(path): _*)
+  }
+
+  /**
+   * Loads an ORC file and returns the result as a [[DataFrame]].
+   *
+   * @param paths input paths
+   * @since 2.0.0
+   * @note Currently, this method can only be used after enabling Hive support.
+   */
+  @scala.annotation.varargs
+  def orc(paths: String*): DataFrame = format("orc").load(paths: _*)
 
   /**
    * Returns the specified table as a [[DataFrame]].
@@ -419,6 +458,18 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
 
   /**
    * Loads text files and returns a [[DataFrame]] whose schema starts with a 
string column named
+   * "value", and followed by partitioned columns if there are any. See the 
documentation on
+   * the other overloaded `text()` method for more details.
+   *
+   * @since 2.0.0
+   */
+  def text(path: String): DataFrame = {
+    // This method ensures that calls that explicit need single argument 
works, see SPARK-16009
+    text(Seq(path): _*)
+  }
+
+  /**
+   * Loads text files and returns a [[DataFrame]] whose schema starts with a 
string column named
    * "value", and followed by partitioned columns if there are any.
    *
    * Each line in the text files is a new row in the resulting DataFrame. For 
example:
@@ -430,13 +481,23 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
    *   spark.read().text("/path/to/spark/README.md")
    * }}}
    *
-   * @param paths input path
+   * @param paths input paths
    * @since 1.6.0
    */
   @scala.annotation.varargs
   def text(paths: String*): DataFrame = format("text").load(paths : _*)
 
   /**
+   * Loads text files and returns a [[Dataset]] of String. See the 
documentation on the
+   * other overloaded `textFile()` method for more details.
+   * @since 2.0.0
+   */
+  def textFile(path: String): Dataset[String] = {
+    // This method ensures that calls that explicit need single argument 
works, see SPARK-16009
+    textFile(Seq(path): _*)
+  }
+
+  /**
    * Loads text files and returns a [[Dataset]] of String. The underlying 
schema of the Dataset
    * contains a single string column named "value".
    *
@@ -457,6 +518,9 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
    */
   @scala.annotation.varargs
   def textFile(paths: String*): Dataset[String] = {
+    if (userSpecifiedSchema.nonEmpty) {
+      throw new AnalysisException("User specified schema not supported with 
`textFile`")
+    }
     text(paths : 
_*).select("value").as[String](sparkSession.implicits.newStringEncoder)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b99129cc/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java
 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java
new file mode 100644
index 0000000..7babf75
--- /dev/null
+++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameReaderWriterSuite.java
@@ -0,0 +1,158 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements.  See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License.  You may obtain a copy of the License at
+*
+*    http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package test.org.apache.spark.sql;
+
+import java.io.File;
+import java.util.HashMap;
+
+import org.apache.spark.sql.SaveMode;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.test.TestSparkSession;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.Utils;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class JavaDataFrameReaderWriterSuite {
+  private SparkSession spark = new TestSparkSession();
+  private StructType schema = new StructType().add("s", "string");
+  private transient String input;
+  private transient String output;
+
+  @Before
+  public void setUp() {
+    input = Utils.createTempDir(System.getProperty("java.io.tmpdir"), 
"input").toString();
+    File f = Utils.createTempDir(System.getProperty("java.io.tmpdir"), 
"output");
+    f.delete();
+    output = f.toString();
+  }
+
+  @After
+  public void tearDown() {
+    spark.stop();
+    spark = null;
+  }
+
+  @Test
+  public void testFormatAPI() {
+    spark
+        .read()
+        .format("org.apache.spark.sql.test")
+        .load()
+        .write()
+        .format("org.apache.spark.sql.test")
+        .save();
+  }
+
+  @Test
+  public void testOptionsAPI() {
+    HashMap<String, String> map = new HashMap<String, String>();
+    map.put("e", "1");
+    spark
+        .read()
+        .option("a", "1")
+        .option("b", 1)
+        .option("c", 1.0)
+        .option("d", true)
+        .options(map)
+        .text()
+        .write()
+        .option("a", "1")
+        .option("b", 1)
+        .option("c", 1.0)
+        .option("d", true)
+        .options(map)
+        .format("org.apache.spark.sql.test")
+        .save();
+  }
+
+  @Test
+  public void testSaveModeAPI() {
+    spark
+        .range(10)
+        .write()
+        .format("org.apache.spark.sql.test")
+        .mode(SaveMode.ErrorIfExists)
+        .save();
+  }
+
+  @Test
+  public void testLoadAPI() {
+    spark.read().format("org.apache.spark.sql.test").load();
+    spark.read().format("org.apache.spark.sql.test").load(input);
+    spark.read().format("org.apache.spark.sql.test").load(input, input, input);
+    spark.read().format("org.apache.spark.sql.test").load(new String[]{input, 
input});
+  }
+
+  @Test
+  public void testTextAPI() {
+    spark.read().text();
+    spark.read().text(input);
+    spark.read().text(input, input, input);
+    spark.read().text(new String[]{input, input})
+        .write().text(output);
+  }
+
+  @Test
+  public void testTextFileAPI() {
+    spark.read().textFile();
+    spark.read().textFile(input);
+    spark.read().textFile(input, input, input);
+    spark.read().textFile(new String[]{input, input});
+  }
+
+  @Test
+  public void testCsvAPI() {
+    spark.read().schema(schema).csv();
+    spark.read().schema(schema).csv(input);
+    spark.read().schema(schema).csv(input, input, input);
+    spark.read().schema(schema).csv(new String[]{input, input})
+        .write().csv(output);
+  }
+
+  @Test
+  public void testJsonAPI() {
+    spark.read().schema(schema).json();
+    spark.read().schema(schema).json(input);
+    spark.read().schema(schema).json(input, input, input);
+    spark.read().schema(schema).json(new String[]{input, input})
+        .write().json(output);
+  }
+
+  @Test
+  public void testParquetAPI() {
+    spark.read().schema(schema).parquet();
+    spark.read().schema(schema).parquet(input);
+    spark.read().schema(schema).parquet(input, input, input);
+    spark.read().schema(schema).parquet(new String[] { input, input })
+        .write().parquet(output);
+  }
+
+  /**
+   * This only tests whether API compiles, but does not run it as orc()
+   * cannot be run without Hive classes.
+   */
+  public void testOrcAPI() {
+    spark.read().schema(schema).orc();
+    spark.read().schema(schema).orc(input);
+    spark.read().schema(schema).orc(input, input, input);
+    spark.read().schema(schema).orc(new String[]{input, input})
+        .write().orc(output);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b99129cc/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 98e57b3..3fa3864 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -17,6 +17,10 @@
 
 package org.apache.spark.sql.test
 
+import java.io.File
+
+import org.scalatest.BeforeAndAfter
+
 import org.apache.spark.sql._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.{StringType, StructField, StructType}
@@ -79,10 +83,19 @@ class DefaultSource
 }
 
 
-class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext {
+class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with 
BeforeAndAfter {
+
 
-  private def newMetadataDir =
-    Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+  private val userSchema = new StructType().add("s", StringType)
+  private val textSchema = new StructType().add("value", StringType)
+  private val data = Seq("1", "2", "3")
+  private val dir = Utils.createTempDir(namePrefix = "input").getCanonicalPath
+  private implicit var enc: Encoder[String] = _
+
+  before {
+    enc = spark.implicits.newStringEncoder
+    Utils.deleteRecursively(new File(dir))
+  }
 
   test("writeStream cannot be called on non-streaming datasets") {
     val e = intercept[AnalysisException] {
@@ -157,24 +170,6 @@ class DataFrameReaderWriterSuite extends QueryTest with 
SharedSQLContext {
     assert(LastOptions.saveMode === SaveMode.ErrorIfExists)
   }
 
-  test("paths") {
-    val df = spark.read
-      .format("org.apache.spark.sql.test")
-      .option("checkpointLocation", newMetadataDir)
-      .load("/test")
-
-    assert(LastOptions.parameters("path") == "/test")
-
-    LastOptions.clear()
-
-    df.write
-      .format("org.apache.spark.sql.test")
-      .option("checkpointLocation", newMetadataDir)
-      .save("/test")
-
-    assert(LastOptions.parameters("path") == "/test")
-  }
-
   test("test different data types for options") {
     val df = spark.read
       .format("org.apache.spark.sql.test")
@@ -193,7 +188,6 @@ class DataFrameReaderWriterSuite extends QueryTest with 
SharedSQLContext {
       .option("intOpt", 56)
       .option("boolOpt", false)
       .option("doubleOpt", 6.7)
-      .option("checkpointLocation", newMetadataDir)
       .save("/test")
 
     assert(LastOptions.parameters("intOpt") == "56")
@@ -228,4 +222,152 @@ class DataFrameReaderWriterSuite extends QueryTest with 
SharedSQLContext {
       }
     }
   }
+
+  test("load API") {
+    spark.read.format("org.apache.spark.sql.test").load()
+    spark.read.format("org.apache.spark.sql.test").load(dir)
+    spark.read.format("org.apache.spark.sql.test").load(dir, dir, dir)
+    spark.read.format("org.apache.spark.sql.test").load(Seq(dir, dir): _*)
+    Option(dir).map(spark.read.format("org.apache.spark.sql.test").load)
+  }
+
+  test("text - API and behavior regarding schema") {
+    // Writer
+    spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
+    testRead(spark.read.text(dir), data, textSchema)
+
+    // Reader, without user specified schema
+    testRead(spark.read.text(), Seq.empty, textSchema)
+    testRead(spark.read.text(dir, dir, dir), data ++ data ++ data, textSchema)
+    testRead(spark.read.text(Seq(dir, dir): _*), data ++ data, textSchema)
+    // Test explicit calls to single arg method - SPARK-16009
+    testRead(Option(dir).map(spark.read.text).get, data, textSchema)
+
+    // Reader, with user specified schema, should just apply user schema on 
the file data
+    testRead(spark.read.schema(userSchema).text(), Seq.empty, userSchema)
+    testRead(spark.read.schema(userSchema).text(dir), data, userSchema)
+    testRead(spark.read.schema(userSchema).text(dir, dir), data ++ data, 
userSchema)
+    testRead(spark.read.schema(userSchema).text(Seq(dir, dir): _*), data ++ 
data, userSchema)
+  }
+
+  test("textFile - API and behavior regarding schema") {
+    spark.createDataset(data).write.mode(SaveMode.Overwrite).text(dir)
+
+    // Reader, without user specified schema
+    testRead(spark.read.textFile().toDF(), Seq.empty, textSchema)
+    testRead(spark.read.textFile(dir).toDF(), data, textSchema)
+    testRead(spark.read.textFile(dir, dir).toDF(), data ++ data, textSchema)
+    testRead(spark.read.textFile(Seq(dir, dir): _*).toDF(), data ++ data, 
textSchema)
+    // Test explicit calls to single arg method - SPARK-16009
+    testRead(Option(dir).map(spark.read.text).get, data, textSchema)
+
+    // Reader, with user specified schema, should just apply user schema on 
the file data
+    val e = intercept[AnalysisException] { 
spark.read.schema(userSchema).textFile() }
+    assert(e.getMessage.toLowerCase.contains("user specified schema not 
supported"))
+    intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir) 
}
+    intercept[AnalysisException] { spark.read.schema(userSchema).textFile(dir, 
dir) }
+    intercept[AnalysisException] { 
spark.read.schema(userSchema).textFile(Seq(dir, dir): _*) }
+  }
+
+  test("csv - API and behavior regarding schema") {
+    // Writer
+    
spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).csv(dir)
+    val df = spark.read.csv(dir)
+    checkAnswer(df, spark.createDataset(data).toDF())
+    val schema = df.schema
+
+    // Reader, without user specified schema
+    intercept[IllegalArgumentException] {
+      testRead(spark.read.csv(), Seq.empty, schema)
+    }
+    testRead(spark.read.csv(dir), data, schema)
+    testRead(spark.read.csv(dir, dir), data ++ data, schema)
+    testRead(spark.read.csv(Seq(dir, dir): _*), data ++ data, schema)
+    // Test explicit calls to single arg method - SPARK-16009
+    testRead(Option(dir).map(spark.read.csv).get, data, schema)
+
+    // Reader, with user specified schema, should just apply user schema on 
the file data
+    testRead(spark.read.schema(userSchema).csv(), Seq.empty, userSchema)
+    testRead(spark.read.schema(userSchema).csv(dir), data, userSchema)
+    testRead(spark.read.schema(userSchema).csv(dir, dir), data ++ data, 
userSchema)
+    testRead(spark.read.schema(userSchema).csv(Seq(dir, dir): _*), data ++ 
data, userSchema)
+  }
+
+  test("json - API and behavior regarding schema") {
+    // Writer
+    
spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).json(dir)
+    val df = spark.read.json(dir)
+    checkAnswer(df, spark.createDataset(data).toDF())
+    val schema = df.schema
+
+    // Reader, without user specified schema
+    intercept[AnalysisException] {
+      testRead(spark.read.json(), Seq.empty, schema)
+    }
+    testRead(spark.read.json(dir), data, schema)
+    testRead(spark.read.json(dir, dir), data ++ data, schema)
+    testRead(spark.read.json(Seq(dir, dir): _*), data ++ data, schema)
+    // Test explicit calls to single arg method - SPARK-16009
+    testRead(Option(dir).map(spark.read.json).get, data, schema)
+
+    // Reader, with user specified schema, data should be nulls as schema in 
file different
+    // from user schema
+    val expData = Seq[String](null, null, null)
+    testRead(spark.read.schema(userSchema).json(), Seq.empty, userSchema)
+    testRead(spark.read.schema(userSchema).json(dir), expData, userSchema)
+    testRead(spark.read.schema(userSchema).json(dir, dir), expData ++ expData, 
userSchema)
+    testRead(spark.read.schema(userSchema).json(Seq(dir, dir): _*), expData ++ 
expData, userSchema)
+  }
+
+  test("parquet - API and behavior regarding schema") {
+    // Writer
+    
spark.createDataset(data).toDF("str").write.mode(SaveMode.Overwrite).parquet(dir)
+    val df = spark.read.parquet(dir)
+    checkAnswer(df, spark.createDataset(data).toDF())
+    val schema = df.schema
+
+    // Reader, without user specified schema
+    intercept[AnalysisException] {
+      testRead(spark.read.parquet(), Seq.empty, schema)
+    }
+    testRead(spark.read.parquet(dir), data, schema)
+    testRead(spark.read.parquet(dir, dir), data ++ data, schema)
+    testRead(spark.read.parquet(Seq(dir, dir): _*), data ++ data, schema)
+    // Test explicit calls to single arg method - SPARK-16009
+    testRead(Option(dir).map(spark.read.parquet).get, data, schema)
+
+    // Reader, with user specified schema, data should be nulls as schema in 
file different
+    // from user schema
+    val expData = Seq[String](null, null, null)
+    testRead(spark.read.schema(userSchema).parquet(), Seq.empty, userSchema)
+    testRead(spark.read.schema(userSchema).parquet(dir), expData, userSchema)
+    testRead(spark.read.schema(userSchema).parquet(dir, dir), expData ++ 
expData, userSchema)
+    testRead(
+      spark.read.schema(userSchema).parquet(Seq(dir, dir): _*), expData ++ 
expData, userSchema)
+  }
+
+  /**
+   * This only tests whether API compiles, but does not run it as orc()
+   * cannot be run without Hive classes.
+   */
+  ignore("orc - API") {
+    // Reader, with user specified schema
+    // Refer to csv-specific test suites for behavior without user specified 
schema
+    spark.read.schema(userSchema).orc()
+    spark.read.schema(userSchema).orc(dir)
+    spark.read.schema(userSchema).orc(dir, dir, dir)
+    spark.read.schema(userSchema).orc(Seq(dir, dir): _*)
+    Option(dir).map(spark.read.schema(userSchema).orc)
+
+    // Writer
+    spark.range(10).write.orc(dir)
+  }
+
+  private def testRead(
+      df: => DataFrame,
+      expectedResult: Seq[String],
+      expectedSchema: StructType): Unit = {
+    checkAnswer(df, spark.createDataset(expectedResult).toDF())
+    assert(df.schema === expectedSchema)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to