Repository: spark
Updated Branches:
refs/heads/master a00181418 -> 39872af88
[SPARK-25684][SQL] Organize header related codes in CSV datasource
## What changes were proposed in this pull request?
1. Move `CSVDataSource.makeSafeHeader` to `CSVUtils.makeSafeHeader` (as is).
- Historically, when I first did this refactoring, I intended to put all
CSV-specific handling (options, filtering, header extraction, etc.) in
`CSVUtils`.
- See `JsonDataSource`. Now `CSVDataSource` is quite consistent with
`JsonDataSource`. Since CSV's code path is quite complicated, we had better
keep the two matched as closely as we can.
2. Create `CSVHeaderChecker` and put the `enforceSchema` logic into it.
- The header checking and column pruning logic was added (per
https://github.com/apache/spark/pull/20894 and
https://github.com/apache/spark/pull/21296), but some of the code, such as
https://github.com/apache/spark/pull/22123, is duplicated.
- Also, the header-checking code is scattered across several places, which
has been quite error-prone. We had better put it in a single place. See
https://github.com/apache/spark/pull/22656.
3. Move `CSVDataSource.checkHeaderColumnNames` to
`CSVHeaderChecker.checkHeaderColumnNames` (as is).
- Same reasons as in 1 above.
## How was this patch tested?
Existing tests should cover this.
Closes #22676 from HyukjinKwon/refactoring-csv.
Authored-by: hyukjinkwon <[email protected]>
Signed-off-by: hyukjinkwon <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/39872af8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/39872af8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/39872af8
Branch: refs/heads/master
Commit: 39872af882e3d73667acfab93c9de962c9c8939d
Parents: a001814
Author: hyukjinkwon <[email protected]>
Authored: Fri Oct 12 09:16:41 2018 +0800
Committer: hyukjinkwon <[email protected]>
Committed: Fri Oct 12 09:16:41 2018 +0800
----------------------------------------------------------------------
.../org/apache/spark/sql/DataFrameReader.scala | 18 +--
.../datasources/csv/CSVDataSource.scala | 161 ++-----------------
.../datasources/csv/CSVFileFormat.scala | 11 +-
.../datasources/csv/CSVHeaderChecker.scala | 131 +++++++++++++++
.../execution/datasources/csv/CSVUtils.scala | 44 ++++-
.../datasources/csv/UnivocityParser.scala | 34 ++--
6 files changed, 217 insertions(+), 182 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/39872af8/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 7269446..3af70b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -505,20 +505,14 @@ class DataFrameReader private[sql](sparkSession:
SparkSession) extends Logging {
val actualSchema =
StructType(schema.filterNot(_.name ==
parsedOptions.columnNameOfCorruptRecord))
- val linesWithoutHeader = if (parsedOptions.headerFlag &&
maybeFirstLine.isDefined) {
- val firstLine = maybeFirstLine.get
- val parser = new CsvParser(parsedOptions.asParserSettings)
- val columnNames = parser.parseLine(firstLine)
- CSVDataSource.checkHeaderColumnNames(
+ val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
+ val headerChecker = new CSVHeaderChecker(
actualSchema,
- columnNames,
- csvDataset.getClass.getCanonicalName,
- parsedOptions.enforceSchema,
- sparkSession.sessionState.conf.caseSensitiveAnalysis)
+ parsedOptions,
+ source = s"CSV source: $csvDataset")
+ headerChecker.checkHeaderColumnNames(firstLine)
filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine,
parsedOptions))
- } else {
- filteredLines.rdd
- }
+ }.getOrElse(filteredLines.rdd)
val parsed = linesWithoutHeader.mapPartitions { iter =>
val rawParser = new UnivocityParser(actualSchema, parsedOptions)
http://git-wip-us.apache.org/repos/asf/spark/blob/39872af8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index b93f418..0b5a719 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -51,11 +51,8 @@ abstract class CSVDataSource extends Serializable {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
- requiredSchema: StructType,
- // Actual schema of data in the csv file
- dataSchema: StructType,
- caseSensitive: Boolean,
- columnPruning: Boolean): Iterator[InternalRow]
+ headerChecker: CSVHeaderChecker,
+ requiredSchema: StructType): Iterator[InternalRow]
/**
* Infers the schema from `inputPaths` files.
@@ -75,48 +72,6 @@ abstract class CSVDataSource extends Serializable {
sparkSession: SparkSession,
inputPaths: Seq[FileStatus],
parsedOptions: CSVOptions): StructType
-
- /**
- * Generates a header from the given row which is null-safe and
duplicate-safe.
- */
- protected def makeSafeHeader(
- row: Array[String],
- caseSensitive: Boolean,
- options: CSVOptions): Array[String] = {
- if (options.headerFlag) {
- val duplicates = {
- val headerNames = row.filter(_ != null)
- // scalastyle:off caselocale
- .map(name => if (caseSensitive) name else name.toLowerCase)
- // scalastyle:on caselocale
- headerNames.diff(headerNames.distinct).distinct
- }
-
- row.zipWithIndex.map { case (value, index) =>
- if (value == null || value.isEmpty || value == options.nullValue) {
- // When there are empty strings or the values set in `nullValue`,
put the
- // index as the suffix.
- s"_c$index"
- // scalastyle:off caselocale
- } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
- // scalastyle:on caselocale
- // When there are case-insensitive duplicates, put the index as the
suffix.
- s"$value$index"
- } else if (duplicates.contains(value)) {
- // When there are duplicates, put the index as the suffix.
- s"$value$index"
- } else {
- value
- }
- }
- } else {
- row.zipWithIndex.map { case (_, index) =>
- // Uses default column names, "_c#" where # is its position of fields
- // when header option is disabled.
- s"_c$index"
- }
- }
- }
}
object CSVDataSource extends Logging {
@@ -127,67 +82,6 @@ object CSVDataSource extends Logging {
TextInputCSVDataSource
}
}
-
- /**
- * Checks that column names in a CSV header and field names in the schema
are the same
- * by taking into account case sensitivity.
- *
- * @param schema - provided (or inferred) schema to which CSV must conform.
- * @param columnNames - names of CSV columns that must be checked against to
the schema.
- * @param fileName - name of CSV file that are currently checked. It is used
in error messages.
- * @param enforceSchema - if it is `true`, column names are ignored
otherwise the CSV column
- * names are checked for conformance to the schema.
In the case if
- * the column name don't conform to the schema, an
exception is thrown.
- * @param caseSensitive - if it is set to `false`, comparison of column
names and schema field
- * names is not case sensitive.
- */
- def checkHeaderColumnNames(
- schema: StructType,
- columnNames: Array[String],
- fileName: String,
- enforceSchema: Boolean,
- caseSensitive: Boolean): Unit = {
- if (columnNames != null) {
- val fieldNames = schema.map(_.name).toIndexedSeq
- val (headerLen, schemaSize) = (columnNames.size, fieldNames.length)
- var errorMessage: Option[String] = None
-
- if (headerLen == schemaSize) {
- var i = 0
- while (errorMessage.isEmpty && i < headerLen) {
- var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i))
- if (!caseSensitive) {
- // scalastyle:off caselocale
- nameInSchema = nameInSchema.toLowerCase
- nameInHeader = nameInHeader.toLowerCase
- // scalastyle:on caselocale
- }
- if (nameInHeader != nameInSchema) {
- errorMessage = Some(
- s"""|CSV header does not conform to the schema.
- | Header: ${columnNames.mkString(", ")}
- | Schema: ${fieldNames.mkString(", ")}
- |Expected: ${fieldNames(i)} but found: ${columnNames(i)}
- |CSV file: $fileName""".stripMargin)
- }
- i += 1
- }
- } else {
- errorMessage = Some(
- s"""|Number of column in CSV header is not equal to number of fields
in the schema:
- | Header length: $headerLen, schema size: $schemaSize
- |CSV file: $fileName""".stripMargin)
- }
-
- errorMessage.foreach { msg =>
- if (enforceSchema) {
- logWarning(msg)
- } else {
- throw new IllegalArgumentException(msg)
- }
- }
- }
- }
}
object TextInputCSVDataSource extends CSVDataSource {
@@ -197,10 +91,8 @@ object TextInputCSVDataSource extends CSVDataSource {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
- requiredSchema: StructType,
- dataSchema: StructType,
- caseSensitive: Boolean,
- columnPruning: Boolean): Iterator[InternalRow] = {
+ headerChecker: CSVHeaderChecker,
+ requiredSchema: StructType): Iterator[InternalRow] = {
val lines = {
val linesReader = new HadoopFileLinesReader(file, conf)
Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ =>
linesReader.close()))
@@ -209,25 +101,7 @@ object TextInputCSVDataSource extends CSVDataSource {
}
}
- val hasHeader = parser.options.headerFlag && file.start == 0
- if (hasHeader) {
- // Checking that column names in the header are matched to field names
of the schema.
- // The header will be removed from lines.
- // Note: if there are only comments in the first block, the header would
probably
- // be not extracted.
- CSVUtils.extractHeader(lines, parser.options).foreach { header =>
- val schema = if (columnPruning) requiredSchema else dataSchema
- val columnNames = parser.tokenizer.parseLine(header)
- CSVDataSource.checkHeaderColumnNames(
- schema,
- columnNames,
- file.filePath,
- parser.options.enforceSchema,
- caseSensitive)
- }
- }
-
- UnivocityParser.parseIterator(lines, parser, requiredSchema)
+ UnivocityParser.parseIterator(lines, parser, headerChecker, requiredSchema)
}
override def infer(
@@ -251,7 +125,7 @@ object TextInputCSVDataSource extends CSVDataSource {
maybeFirstLine.map(csvParser.parseLine(_)) match {
case Some(firstRow) if firstRow != null =>
val caseSensitive =
sparkSession.sessionState.conf.caseSensitiveAnalysis
- val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+ val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive,
parsedOptions)
val sampled: Dataset[String] = CSVUtils.sample(csv, parsedOptions)
val tokenRDD = sampled.rdd.mapPartitions { iter =>
val filteredLines = CSVUtils.filterCommentAndEmpty(iter,
parsedOptions)
@@ -298,26 +172,13 @@ object MultiLineCSVDataSource extends CSVDataSource {
conf: Configuration,
file: PartitionedFile,
parser: UnivocityParser,
- requiredSchema: StructType,
- dataSchema: StructType,
- caseSensitive: Boolean,
- columnPruning: Boolean): Iterator[InternalRow] = {
- def checkHeader(header: Array[String]): Unit = {
- val schema = if (columnPruning) requiredSchema else dataSchema
- CSVDataSource.checkHeaderColumnNames(
- schema,
- header,
- file.filePath,
- parser.options.enforceSchema,
- caseSensitive)
- }
-
+ headerChecker: CSVHeaderChecker,
+ requiredSchema: StructType): Iterator[InternalRow] = {
UnivocityParser.parseStream(
CodecStreams.createInputStreamWithCloseResource(conf, new Path(new
URI(file.filePath))),
- parser.options.headerFlag,
parser,
- requiredSchema,
- checkHeader)
+ headerChecker,
+ requiredSchema)
}
override def infer(
@@ -334,7 +195,7 @@ object MultiLineCSVDataSource extends CSVDataSource {
}.take(1).headOption match {
case Some(firstRow) =>
val caseSensitive =
sparkSession.sessionState.conf.caseSensitiveAnalysis
- val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions)
+ val header = CSVUtils.makeSafeHeader(firstRow, caseSensitive,
parsedOptions)
val tokenRDD = csv.flatMap { lines =>
UnivocityParser.tokenizeStream(
CodecStreams.createInputStreamWithCloseResource(
http://git-wip-us.apache.org/repos/asf/spark/blob/39872af8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index 9aad0bd..3de1c2d 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -130,7 +130,6 @@ class CSVFileFormat extends TextBasedFileFormat with
DataSourceRegister {
"df.filter($\"_corrupt_record\".isNotNull).count()."
)
}
- val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
val columnPruning = sparkSession.sessionState.conf.csvColumnPruning
(file: PartitionedFile) => {
@@ -139,14 +138,16 @@ class CSVFileFormat extends TextBasedFileFormat with
DataSourceRegister {
StructType(dataSchema.filterNot(_.name ==
parsedOptions.columnNameOfCorruptRecord)),
StructType(requiredSchema.filterNot(_.name ==
parsedOptions.columnNameOfCorruptRecord)),
parsedOptions)
+ val schema = if (columnPruning) requiredSchema else dataSchema
+ val isStartOfFile = file.start == 0
+ val headerChecker = new CSVHeaderChecker(
+ schema, parsedOptions, source = s"CSV file: ${file.filePath}",
isStartOfFile)
CSVDataSource(parsedOptions).readFile(
conf,
file,
parser,
- requiredSchema,
- dataSchema,
- caseSensitive,
- columnPruning)
+ headerChecker,
+ requiredSchema)
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/39872af8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala
new file mode 100644
index 0000000..558ee91
--- /dev/null
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVHeaderChecker.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import com.univocity.parsers.csv.CsvParser
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Checks that column names in a CSV header and field names in the schema are
the same
+ * by taking into account case sensitivity.
+ *
+ * @param schema provided (or inferred) schema to which CSV must conform.
+ * @param options parsed CSV options.
+ * @param source name of CSV source that are currently checked. It is used in
error messages.
+ * @param isStartOfFile indicates if the currently processing partition is the
start of the file.
+ * if unknown or not applicable (for instance when the
input is a dataset),
+ * can be omitted.
+ */
+class CSVHeaderChecker(
+ schema: StructType,
+ options: CSVOptions,
+ source: String,
+ isStartOfFile: Boolean = false) extends Logging {
+
+ // Indicates if it is set to `false`, comparison of column names and schema
field
+ // names is not case sensitive.
+ private val caseSensitive = SQLConf.get.caseSensitiveAnalysis
+
+ // Indicates if it is `true`, column names are ignored otherwise the CSV
column
+ // names are checked for conformance to the schema. In the case if
+ // the column name don't conform to the schema, an exception is thrown.
+ private val enforceSchema = options.enforceSchema
+
+ /**
+ * Checks that column names in a CSV header and field names in the schema
are the same
+ * by taking into account case sensitivity.
+ *
+ * @param columnNames names of CSV columns that must be checked against to
the schema.
+ */
+ private def checkHeaderColumnNames(columnNames: Array[String]): Unit = {
+ if (columnNames != null) {
+ val fieldNames = schema.map(_.name).toIndexedSeq
+ val (headerLen, schemaSize) = (columnNames.size, fieldNames.length)
+ var errorMessage: Option[String] = None
+
+ if (headerLen == schemaSize) {
+ var i = 0
+ while (errorMessage.isEmpty && i < headerLen) {
+ var (nameInSchema, nameInHeader) = (fieldNames(i), columnNames(i))
+ if (!caseSensitive) {
+ // scalastyle:off caselocale
+ nameInSchema = nameInSchema.toLowerCase
+ nameInHeader = nameInHeader.toLowerCase
+ // scalastyle:on caselocale
+ }
+ if (nameInHeader != nameInSchema) {
+ errorMessage = Some(
+ s"""|CSV header does not conform to the schema.
+ | Header: ${columnNames.mkString(", ")}
+ | Schema: ${fieldNames.mkString(", ")}
+ |Expected: ${fieldNames(i)} but found: ${columnNames(i)}
+ |$source""".stripMargin)
+ }
+ i += 1
+ }
+ } else {
+ errorMessage = Some(
+ s"""|Number of column in CSV header is not equal to number of fields
in the schema:
+ | Header length: $headerLen, schema size: $schemaSize
+ |$source""".stripMargin)
+ }
+
+ errorMessage.foreach { msg =>
+ if (enforceSchema) {
+ logWarning(msg)
+ } else {
+ throw new IllegalArgumentException(msg)
+ }
+ }
+ }
+ }
+
+ // This is currently only used to parse CSV from Dataset[String].
+ def checkHeaderColumnNames(line: String): Unit = {
+ if (options.headerFlag) {
+ val parser = new CsvParser(options.asParserSettings)
+ checkHeaderColumnNames(parser.parseLine(line))
+ }
+ }
+
+ // This is currently only used to parse CSV with multiLine mode.
+ private[csv] def checkHeaderColumnNames(tokenizer: CsvParser): Unit = {
+ assert(options.multiLine, "This method should be executed with multiLine.")
+ if (options.headerFlag) {
+ val firstRecord = tokenizer.parseNext()
+ checkHeaderColumnNames(firstRecord)
+ }
+ }
+
+ // This is currently only used to parse CSV with non-multiLine mode.
+ private[csv] def checkHeaderColumnNames(lines: Iterator[String], tokenizer:
CsvParser): Unit = {
+ assert(!options.multiLine, "This method should not be executed with
multiline.")
+ // Checking that column names in the header are matched to field names of
the schema.
+ // The header will be removed from lines.
+ // Note: if there are only comments in the first block, the header would
probably
+ // be not extracted.
+ if (options.headerFlag && isStartOfFile) {
+ CSVUtils.extractHeader(lines, options).foreach { header =>
+ checkHeaderColumnNames(tokenizer.parseLine(header))
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/39872af8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
index 7ce65fa..b912f8a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.datasources.csv
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
object CSVUtils {
/**
@@ -90,6 +89,49 @@ object CSVUtils {
None
}
}
+
+ /**
+ * Generates a header from the given row which is null-safe and
duplicate-safe.
+ */
+ def makeSafeHeader(
+ row: Array[String],
+ caseSensitive: Boolean,
+ options: CSVOptions): Array[String] = {
+ if (options.headerFlag) {
+ val duplicates = {
+ val headerNames = row.filter(_ != null)
+ // scalastyle:off caselocale
+ .map(name => if (caseSensitive) name else name.toLowerCase)
+ // scalastyle:on caselocale
+ headerNames.diff(headerNames.distinct).distinct
+ }
+
+ row.zipWithIndex.map { case (value, index) =>
+ if (value == null || value.isEmpty || value == options.nullValue) {
+ // When there are empty strings or the values set in `nullValue`,
put the
+ // index as the suffix.
+ s"_c$index"
+ // scalastyle:off caselocale
+ } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) {
+ // scalastyle:on caselocale
+ // When there are case-insensitive duplicates, put the index as the
suffix.
+ s"$value$index"
+ } else if (duplicates.contains(value)) {
+ // When there are duplicates, put the index as the suffix.
+ s"$value$index"
+ } else {
+ value
+ }
+ }
+ } else {
+ row.zipWithIndex.map { case (_, index) =>
+ // Uses default column names, "_c#" where # is its position of fields
+ // when header option is disabled.
+ s"_c$index"
+ }
+ }
+ }
+
/**
* Helper method that converts string representation of a character to
actual character.
* It handles some Java escaped strings and throws exception if given string
is longer than one
http://git-wip-us.apache.org/repos/asf/spark/blob/39872af8/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
index 9088d43..fbd19c6 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala
@@ -273,7 +273,10 @@ private[csv] object UnivocityParser {
inputStream: InputStream,
shouldDropHeader: Boolean,
tokenizer: CsvParser): Iterator[Array[String]] = {
- convertStream(inputStream, shouldDropHeader, tokenizer)(tokens => tokens)
+ val handleHeader: () => Unit =
+ () => if (shouldDropHeader) tokenizer.parseNext
+
+ convertStream(inputStream, tokenizer, handleHeader)(tokens => tokens)
}
/**
@@ -281,10 +284,9 @@ private[csv] object UnivocityParser {
*/
def parseStream(
inputStream: InputStream,
- shouldDropHeader: Boolean,
parser: UnivocityParser,
- schema: StructType,
- checkHeader: Array[String] => Unit): Iterator[InternalRow] = {
+ headerChecker: CSVHeaderChecker,
+ schema: StructType): Iterator[InternalRow] = {
val tokenizer = parser.tokenizer
val safeParser = new FailureSafeParser[Array[String]](
input => Seq(parser.convert(input)),
@@ -292,25 +294,26 @@ private[csv] object UnivocityParser {
schema,
parser.options.columnNameOfCorruptRecord,
parser.options.multiLine)
- convertStream(inputStream, shouldDropHeader, tokenizer, checkHeader) {
tokens =>
+
+ val handleHeader: () => Unit =
+ () => headerChecker.checkHeaderColumnNames(tokenizer)
+
+ convertStream(inputStream, tokenizer, handleHeader) { tokens =>
safeParser.parse(tokens)
}.flatten
}
private def convertStream[T](
inputStream: InputStream,
- shouldDropHeader: Boolean,
tokenizer: CsvParser,
- checkHeader: Array[String] => Unit = _ => ())(
+ handleHeader: () => Unit)(
convert: Array[String] => T) = new Iterator[T] {
tokenizer.beginParsing(inputStream)
- private var nextRecord = {
- if (shouldDropHeader) {
- val firstRecord = tokenizer.parseNext()
- checkHeader(firstRecord)
- }
- tokenizer.parseNext()
- }
+
+ // We can handle header here since here the stream is open.
+ handleHeader()
+
+ private var nextRecord = tokenizer.parseNext()
override def hasNext: Boolean = nextRecord != null
@@ -330,7 +333,10 @@ private[csv] object UnivocityParser {
def parseIterator(
lines: Iterator[String],
parser: UnivocityParser,
+ headerChecker: CSVHeaderChecker,
schema: StructType): Iterator[InternalRow] = {
+ headerChecker.checkHeaderColumnNames(lines, parser.tokenizer)
+
val options = parser.options
val filteredLines: Iterator[String] =
CSVUtils.filterCommentAndEmpty(lines, options)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]