zhztheplayer commented on code in PR #5447:
URL: https://github.com/apache/incubator-gluten/pull/5447#discussion_r1578836085
##########
shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala:
##########
@@ -1460,6 +1462,13 @@ object GlutenConfig {
.booleanConf
.createOptional
+ val NATIVE_ARROW_READER_ENABLED =
+ buildConf("spark.gluten.sql.native.arrow.reader.enabled")
+ .internal()
+ .doc("This is config to specify whether to enable the native columnar
csv reader")
+ .booleanConf
+ .createWithDefault(true)
Review Comment:
Should we make it `false` until we reach some milestone, e.g., once we have
fixed `GlutenCSVSuite`?
##########
shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala:
##########
@@ -229,4 +230,6 @@ trait SparkShims {
def withAnsiEvalMode(expr: Expression): Boolean = false
+ def dateTimestampFormatInReadDefault(csvOptions: CSVOptions, timeZone:
String): Boolean
Review Comment:
I don't quite understand what the method name means. If no self-explanatory
name can be found, could you add some comments explaining this method?
##########
gluten-data/src/main/scala/org/apache/gluten/utils/ArrowUtil.scala:
##########
@@ -86,4 +101,223 @@ object ArrowUtil extends Logging {
}
new Schema(fields)
}
+
+ def getFormat(format: String): FileFormat = {
+ format match {
+ case "parquet" => FileFormat.PARQUET
+ case "orc" => FileFormat.ORC
+ case "csv" => FileFormat.CSV
+ case _ => throw new IllegalArgumentException("Unrecognizable format")
+ }
+ }
+
+ def getFormat(format:
org.apache.spark.sql.execution.datasources.FileFormat): FileFormat = {
+ format match {
+ case _: ParquetFileFormat =>
+ FileFormat.PARQUET
+ case _: CSVFileFormat =>
+ FileFormat.CSV
+ case _ =>
+ throw new IllegalArgumentException("Unrecognizable format")
+ }
+ }
+
+ private def rewriteUri(encodeUri: String): String = {
+ val decodedUri = encodeUri
+ val uri = URI.create(decodedUri)
+ if (uri.getScheme == "s3" || uri.getScheme == "s3a") {
+ val s3Rewritten =
+ new URI("s3", uri.getAuthority, uri.getPath, uri.getQuery,
uri.getFragment).toString
+ return s3Rewritten
+ }
+ val sch = uri.getScheme match {
+ case "hdfs" => "hdfs"
+ case "file" => "file"
+ }
+ val ssp = uri.getScheme match {
+ case "hdfs" => uri.getSchemeSpecificPart
+ case "file" => "//" + uri.getSchemeSpecificPart
+ }
+ val rewritten = new URI(sch, ssp, uri.getFragment)
+ rewritten.toString
+ }
+
+ def makeArrowDiscovery(encodedUri: String, format: FileFormat):
FileSystemDatasetFactory = {
+ val allocator = ArrowBufferAllocators.contextInstance()
+ val factory = new FileSystemDatasetFactory(
+ allocator,
+ NativeMemoryPool.getDefault, // TODO: wait to change
+ format,
+ rewriteUri(encodedUri))
+ factory
+ }
+
+ def readSchema(file: FileStatus, format: FileFormat): Option[StructType] = {
+ val factory: FileSystemDatasetFactory =
+ makeArrowDiscovery(file.getPath.toString, format)
+ val schema = factory.inspect()
+ try {
+ Option(SparkSchemaUtil.fromArrowSchema(schema))
+ } finally {
+ factory.close()
+ }
+ }
+
+ def readSchema(files: Seq[FileStatus], format: FileFormat):
Option[StructType] = {
+ if (files.isEmpty) {
+ throw new IllegalArgumentException("No input file specified")
+ }
+
+ readSchema(files.head, format)
+ }
+
+ def compareStringFunc(caseSensitive: Boolean): (String, String) => Boolean =
{
+ if (caseSensitive) { (str1: String, str2: String) => str1.equals(str2) }
+ else { (str1: String, str2: String) => str1.equalsIgnoreCase(str2) }
+ }
+
+ // If user specify schema by .schema(newSchemaDifferentWithFile)
+ def checkSchema(
+ requiredField: DataType,
+ parquetFileFieldType: ArrowType,
+ parquetFileFields: mutable.Buffer[Field]): Unit = {
+ val requiredFieldType =
+ SparkArrowUtil.toArrowType(requiredField,
SparkSchemaUtil.getLocalTimezoneID)
+ if (!requiredFieldType.equals(parquetFileFieldType)) {
+ val arrowFileSchema = parquetFileFields
+ .map(f => f.toString)
+ .reduceLeft((f1, f2) => f1 + "\n" + f2)
+ throw new SchemaMismatchException(
+ s"Not support specified schema is different with file
schema\n$arrowFileSchema")
+ }
+ }
+
+ def getRequestedField(
+ requiredSchema: StructType,
+ parquetFileFields: mutable.Buffer[Field],
+ caseSensitive: Boolean): Schema = {
+ val compareFunc = compareStringFunc(caseSensitive)
+ requiredSchema.foreach {
+ readField =>
+ // TODO: check schema inside of complex type
+ val matchedFields =
+ parquetFileFields.filter(field => compareFunc(field.getName,
readField.name))
+ if (!caseSensitive && matchedFields.size > 1) {
+ // Need to fail if there is ambiguity, i.e. more than one field is
matched
+ val fieldsString = matchedFields.map(_.getName).mkString("[", ", ",
"]")
+ throw new RuntimeException(
+ s"""
+ |Found duplicate field(s) "${readField.name}": $fieldsString
+
+ |in case-insensitive mode""".stripMargin.replaceAll("\n", " "))
+ }
+ if (matchedFields.nonEmpty) {
+ checkSchema(
+ readField.dataType,
+ matchedFields.head.getFieldType.getType,
+ parquetFileFields)
+ }
+ }
+
+ val requestColNames = requiredSchema.map(_.name)
+ new Schema(parquetFileFields.filter {
+ field => requestColNames.exists(col => compareFunc(col, field.getName))
+ }.asJava)
+ }
+
+ def loadMissingColumns(
+ rowCount: Int,
+ missingSchema: StructType): Array[ArrowWritableColumnVector] = {
+
+ val vectors =
+ ArrowWritableColumnVector.allocateColumns(rowCount, missingSchema)
+ vectors.foreach {
+ vector =>
+ vector.putNulls(0, rowCount)
+ vector.setValueCount(rowCount)
+ }
+
+ SparkMemoryUtil.addLeakSafeTaskCompletionListener(
+ (_: TaskContext) => {
+ vectors.foreach(_.close())
+ })
+
+ vectors
+ }
+
+ def loadPartitionColumns(
+ rowCount: Int,
+ partitionSchema: StructType,
+ partitionValues: InternalRow): Array[ArrowWritableColumnVector] = {
+ val partitionColumns = ArrowWritableColumnVector.allocateColumns(rowCount,
partitionSchema)
+ (0 until partitionColumns.length).foreach(
+ i => {
+ ArrowColumnVectorUtils.populate(partitionColumns(i), partitionValues,
i)
+ partitionColumns(i).setValueCount(rowCount)
+ partitionColumns(i).setIsConstant()
+ })
+
+ SparkMemoryUtil.addLeakSafeTaskCompletionListener(
+ (_: TaskContext) => {
+ partitionColumns.foreach(_.close())
+ })
Review Comment:
Ditto: is there a way to avoid closing these vectors through a task completion listener?
##########
gluten-data/src/main/scala/org/apache/gluten/utils/ArrowUtil.scala:
##########
@@ -86,4 +101,223 @@ object ArrowUtil extends Logging {
}
new Schema(fields)
}
+
+ def getFormat(format: String): FileFormat = {
+ format match {
+ case "parquet" => FileFormat.PARQUET
+ case "orc" => FileFormat.ORC
+ case "csv" => FileFormat.CSV
+ case _ => throw new IllegalArgumentException("Unrecognizable format")
+ }
+ }
+
+ def getFormat(format:
org.apache.spark.sql.execution.datasources.FileFormat): FileFormat = {
+ format match {
+ case _: ParquetFileFormat =>
+ FileFormat.PARQUET
+ case _: CSVFileFormat =>
+ FileFormat.CSV
+ case _ =>
+ throw new IllegalArgumentException("Unrecognizable format")
+ }
+ }
+
+ private def rewriteUri(encodeUri: String): String = {
+ val decodedUri = encodeUri
+ val uri = URI.create(decodedUri)
+ if (uri.getScheme == "s3" || uri.getScheme == "s3a") {
+ val s3Rewritten =
+ new URI("s3", uri.getAuthority, uri.getPath, uri.getQuery,
uri.getFragment).toString
+ return s3Rewritten
+ }
+ val sch = uri.getScheme match {
+ case "hdfs" => "hdfs"
+ case "file" => "file"
+ }
+ val ssp = uri.getScheme match {
+ case "hdfs" => uri.getSchemeSpecificPart
+ case "file" => "//" + uri.getSchemeSpecificPart
+ }
+ val rewritten = new URI(sch, ssp, uri.getFragment)
+ rewritten.toString
+ }
+
+ def makeArrowDiscovery(encodedUri: String, format: FileFormat):
FileSystemDatasetFactory = {
+ val allocator = ArrowBufferAllocators.contextInstance()
+ val factory = new FileSystemDatasetFactory(
+ allocator,
+ NativeMemoryPool.getDefault, // TODO: wait to change
+ format,
+ rewriteUri(encodedUri))
+ factory
+ }
+
+ def readSchema(file: FileStatus, format: FileFormat): Option[StructType] = {
+ val factory: FileSystemDatasetFactory =
+ makeArrowDiscovery(file.getPath.toString, format)
+ val schema = factory.inspect()
+ try {
+ Option(SparkSchemaUtil.fromArrowSchema(schema))
+ } finally {
+ factory.close()
+ }
+ }
+
+ def readSchema(files: Seq[FileStatus], format: FileFormat):
Option[StructType] = {
+ if (files.isEmpty) {
+ throw new IllegalArgumentException("No input file specified")
+ }
+
+ readSchema(files.head, format)
+ }
+
+ def compareStringFunc(caseSensitive: Boolean): (String, String) => Boolean =
{
+ if (caseSensitive) { (str1: String, str2: String) => str1.equals(str2) }
+ else { (str1: String, str2: String) => str1.equalsIgnoreCase(str2) }
+ }
+
+ // If user specify schema by .schema(newSchemaDifferentWithFile)
+ def checkSchema(
+ requiredField: DataType,
+ parquetFileFieldType: ArrowType,
+ parquetFileFields: mutable.Buffer[Field]): Unit = {
+ val requiredFieldType =
+ SparkArrowUtil.toArrowType(requiredField,
SparkSchemaUtil.getLocalTimezoneID)
+ if (!requiredFieldType.equals(parquetFileFieldType)) {
+ val arrowFileSchema = parquetFileFields
+ .map(f => f.toString)
+ .reduceLeft((f1, f2) => f1 + "\n" + f2)
+ throw new SchemaMismatchException(
+ s"Not support specified schema is different with file
schema\n$arrowFileSchema")
+ }
+ }
+
+ def getRequestedField(
+ requiredSchema: StructType,
+ parquetFileFields: mutable.Buffer[Field],
+ caseSensitive: Boolean): Schema = {
+ val compareFunc = compareStringFunc(caseSensitive)
+ requiredSchema.foreach {
+ readField =>
+ // TODO: check schema inside of complex type
+ val matchedFields =
+ parquetFileFields.filter(field => compareFunc(field.getName,
readField.name))
+ if (!caseSensitive && matchedFields.size > 1) {
+ // Need to fail if there is ambiguity, i.e. more than one field is
matched
+ val fieldsString = matchedFields.map(_.getName).mkString("[", ", ",
"]")
+ throw new RuntimeException(
+ s"""
+ |Found duplicate field(s) "${readField.name}": $fieldsString
+
+ |in case-insensitive mode""".stripMargin.replaceAll("\n", " "))
+ }
+ if (matchedFields.nonEmpty) {
+ checkSchema(
+ readField.dataType,
+ matchedFields.head.getFieldType.getType,
+ parquetFileFields)
+ }
+ }
+
+ val requestColNames = requiredSchema.map(_.name)
+ new Schema(parquetFileFields.filter {
+ field => requestColNames.exists(col => compareFunc(col, field.getName))
+ }.asJava)
+ }
+
+ def loadMissingColumns(
+ rowCount: Int,
+ missingSchema: StructType): Array[ArrowWritableColumnVector] = {
+
+ val vectors =
+ ArrowWritableColumnVector.allocateColumns(rowCount, missingSchema)
+ vectors.foreach {
+ vector =>
+ vector.putNulls(0, rowCount)
+ vector.setValueCount(rowCount)
+ }
+
+ SparkMemoryUtil.addLeakSafeTaskCompletionListener(
+ (_: TaskContext) => {
+ vectors.foreach(_.close())
+ })
Review Comment:
Is there a way to avoid closing these vectors through a task completion
listener?
##########
gluten-core/src/main/scala/org/apache/gluten/utils/PlanUtil.scala:
##########
@@ -70,6 +70,9 @@ object PlanUtil {
case _: InputAdapter => false
case _: WholeStageCodegenExec => false
case r: ReusedExchangeExec => isVanillaColumnarOp(r.child)
+ case f: FileSourceScanExec =>
+
!f.relation.fileFormat.getClass.getSimpleName.equals("ArrowCSVFileFormat") &&
+ f.supportsColumnar
Review Comment:
Could you explain the reason for this change? Thanks.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]