zhztheplayer commented on code in PR #5447:
URL: https://github.com/apache/incubator-gluten/pull/5447#discussion_r1578836085


##########
shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala:
##########
@@ -1460,6 +1462,13 @@ object GlutenConfig {
       .booleanConf
       .createOptional
 
+  val NATIVE_ARROW_READER_ENABLED =
+    buildConf("spark.gluten.sql.native.arrow.reader.enabled")
+      .internal()
+      .doc("This is config to specify whether to enable the native columnar 
csv reader")
+      .booleanConf
+      .createWithDefault(true)

Review Comment:
   Should we make it `false` until we reach some milestone, e.g., once 
`GlutenCSVSuite` is fixed?



##########
shims/common/src/main/scala/org/apache/gluten/sql/shims/SparkShims.scala:
##########
@@ -229,4 +230,6 @@ trait SparkShims {
 
   def withAnsiEvalMode(expr: Expression): Boolean = false
 
+  def dateTimestampFormatInReadDefault(csvOptions: CSVOptions, timeZone: 
String): Boolean

Review Comment:
   I don't quite understand what the method name means. If no self-explanatory 
name can be found, could you add some comments to this method?



##########
gluten-data/src/main/scala/org/apache/gluten/utils/ArrowUtil.scala:
##########
@@ -86,4 +101,223 @@ object ArrowUtil extends Logging {
     }
     new Schema(fields)
   }
+
+  def getFormat(format: String): FileFormat = {
+    format match {
+      case "parquet" => FileFormat.PARQUET
+      case "orc" => FileFormat.ORC
+      case "csv" => FileFormat.CSV
+      case _ => throw new IllegalArgumentException("Unrecognizable format")
+    }
+  }
+
+  def getFormat(format: 
org.apache.spark.sql.execution.datasources.FileFormat): FileFormat = {
+    format match {
+      case _: ParquetFileFormat =>
+        FileFormat.PARQUET
+      case _: CSVFileFormat =>
+        FileFormat.CSV
+      case _ =>
+        throw new IllegalArgumentException("Unrecognizable format")
+    }
+  }
+
+  private def rewriteUri(encodeUri: String): String = {
+    val decodedUri = encodeUri
+    val uri = URI.create(decodedUri)
+    if (uri.getScheme == "s3" || uri.getScheme == "s3a") {
+      val s3Rewritten =
+        new URI("s3", uri.getAuthority, uri.getPath, uri.getQuery, 
uri.getFragment).toString
+      return s3Rewritten
+    }
+    val sch = uri.getScheme match {
+      case "hdfs" => "hdfs"
+      case "file" => "file"
+    }
+    val ssp = uri.getScheme match {
+      case "hdfs" => uri.getSchemeSpecificPart
+      case "file" => "//" + uri.getSchemeSpecificPart
+    }
+    val rewritten = new URI(sch, ssp, uri.getFragment)
+    rewritten.toString
+  }
+
+  def makeArrowDiscovery(encodedUri: String, format: FileFormat): 
FileSystemDatasetFactory = {
+    val allocator = ArrowBufferAllocators.contextInstance()
+    val factory = new FileSystemDatasetFactory(
+      allocator,
+      NativeMemoryPool.getDefault, // TODO: wait to change
+      format,
+      rewriteUri(encodedUri))
+    factory
+  }
+
+  def readSchema(file: FileStatus, format: FileFormat): Option[StructType] = {
+    val factory: FileSystemDatasetFactory =
+      makeArrowDiscovery(file.getPath.toString, format)
+    val schema = factory.inspect()
+    try {
+      Option(SparkSchemaUtil.fromArrowSchema(schema))
+    } finally {
+      factory.close()
+    }
+  }
+
+  def readSchema(files: Seq[FileStatus], format: FileFormat): 
Option[StructType] = {
+    if (files.isEmpty) {
+      throw new IllegalArgumentException("No input file specified")
+    }
+
+    readSchema(files.head, format)
+  }
+
+  def compareStringFunc(caseSensitive: Boolean): (String, String) => Boolean = 
{
+    if (caseSensitive) { (str1: String, str2: String) => str1.equals(str2) }
+    else { (str1: String, str2: String) => str1.equalsIgnoreCase(str2) }
+  }
+
+  // If user specify schema by .schema(newSchemaDifferentWithFile)
+  def checkSchema(
+      requiredField: DataType,
+      parquetFileFieldType: ArrowType,
+      parquetFileFields: mutable.Buffer[Field]): Unit = {
+    val requiredFieldType =
+      SparkArrowUtil.toArrowType(requiredField, 
SparkSchemaUtil.getLocalTimezoneID)
+    if (!requiredFieldType.equals(parquetFileFieldType)) {
+      val arrowFileSchema = parquetFileFields
+        .map(f => f.toString)
+        .reduceLeft((f1, f2) => f1 + "\n" + f2)
+      throw new SchemaMismatchException(
+        s"Not support specified schema is different with file 
schema\n$arrowFileSchema")
+    }
+  }
+
+  def getRequestedField(
+      requiredSchema: StructType,
+      parquetFileFields: mutable.Buffer[Field],
+      caseSensitive: Boolean): Schema = {
+    val compareFunc = compareStringFunc(caseSensitive)
+    requiredSchema.foreach {
+      readField =>
+        // TODO: check schema inside of complex type
+        val matchedFields =
+          parquetFileFields.filter(field => compareFunc(field.getName, 
readField.name))
+        if (!caseSensitive && matchedFields.size > 1) {
+          // Need to fail if there is ambiguity, i.e. more than one field is 
matched
+          val fieldsString = matchedFields.map(_.getName).mkString("[", ", ", 
"]")
+          throw new RuntimeException(
+            s"""
+               |Found duplicate field(s) "${readField.name}": $fieldsString
+
+               |in case-insensitive mode""".stripMargin.replaceAll("\n", " "))
+        }
+        if (matchedFields.nonEmpty) {
+          checkSchema(
+            readField.dataType,
+            matchedFields.head.getFieldType.getType,
+            parquetFileFields)
+        }
+    }
+
+    val requestColNames = requiredSchema.map(_.name)
+    new Schema(parquetFileFields.filter {
+      field => requestColNames.exists(col => compareFunc(col, field.getName))
+    }.asJava)
+  }
+
+  def loadMissingColumns(
+      rowCount: Int,
+      missingSchema: StructType): Array[ArrowWritableColumnVector] = {
+
+    val vectors =
+      ArrowWritableColumnVector.allocateColumns(rowCount, missingSchema)
+    vectors.foreach {
+      vector =>
+        vector.putNulls(0, rowCount)
+        vector.setValueCount(rowCount)
+    }
+
+    SparkMemoryUtil.addLeakSafeTaskCompletionListener(
+      (_: TaskContext) => {
+        vectors.foreach(_.close())
+      })
+
+    vectors
+  }
+
+  def loadPartitionColumns(
+      rowCount: Int,
+      partitionSchema: StructType,
+      partitionValues: InternalRow): Array[ArrowWritableColumnVector] = {
+    val partitionColumns = ArrowWritableColumnVector.allocateColumns(rowCount, 
partitionSchema)
+    (0 until partitionColumns.length).foreach(
+      i => {
+        ArrowColumnVectorUtils.populate(partitionColumns(i), partitionValues, 
i)
+        partitionColumns(i).setValueCount(rowCount)
+        partitionColumns(i).setIsConstant()
+      })
+
+    SparkMemoryUtil.addLeakSafeTaskCompletionListener(
+      (_: TaskContext) => {
+        partitionColumns.foreach(_.close())
+      })

Review Comment:
   ditto



##########
gluten-data/src/main/scala/org/apache/gluten/utils/ArrowUtil.scala:
##########
@@ -86,4 +101,223 @@ object ArrowUtil extends Logging {
     }
     new Schema(fields)
   }
+
+  def getFormat(format: String): FileFormat = {
+    format match {
+      case "parquet" => FileFormat.PARQUET
+      case "orc" => FileFormat.ORC
+      case "csv" => FileFormat.CSV
+      case _ => throw new IllegalArgumentException("Unrecognizable format")
+    }
+  }
+
+  def getFormat(format: 
org.apache.spark.sql.execution.datasources.FileFormat): FileFormat = {
+    format match {
+      case _: ParquetFileFormat =>
+        FileFormat.PARQUET
+      case _: CSVFileFormat =>
+        FileFormat.CSV
+      case _ =>
+        throw new IllegalArgumentException("Unrecognizable format")
+    }
+  }
+
+  private def rewriteUri(encodeUri: String): String = {
+    val decodedUri = encodeUri
+    val uri = URI.create(decodedUri)
+    if (uri.getScheme == "s3" || uri.getScheme == "s3a") {
+      val s3Rewritten =
+        new URI("s3", uri.getAuthority, uri.getPath, uri.getQuery, 
uri.getFragment).toString
+      return s3Rewritten
+    }
+    val sch = uri.getScheme match {
+      case "hdfs" => "hdfs"
+      case "file" => "file"
+    }
+    val ssp = uri.getScheme match {
+      case "hdfs" => uri.getSchemeSpecificPart
+      case "file" => "//" + uri.getSchemeSpecificPart
+    }
+    val rewritten = new URI(sch, ssp, uri.getFragment)
+    rewritten.toString
+  }
+
+  def makeArrowDiscovery(encodedUri: String, format: FileFormat): 
FileSystemDatasetFactory = {
+    val allocator = ArrowBufferAllocators.contextInstance()
+    val factory = new FileSystemDatasetFactory(
+      allocator,
+      NativeMemoryPool.getDefault, // TODO: wait to change
+      format,
+      rewriteUri(encodedUri))
+    factory
+  }
+
+  def readSchema(file: FileStatus, format: FileFormat): Option[StructType] = {
+    val factory: FileSystemDatasetFactory =
+      makeArrowDiscovery(file.getPath.toString, format)
+    val schema = factory.inspect()
+    try {
+      Option(SparkSchemaUtil.fromArrowSchema(schema))
+    } finally {
+      factory.close()
+    }
+  }
+
+  def readSchema(files: Seq[FileStatus], format: FileFormat): 
Option[StructType] = {
+    if (files.isEmpty) {
+      throw new IllegalArgumentException("No input file specified")
+    }
+
+    readSchema(files.head, format)
+  }
+
+  def compareStringFunc(caseSensitive: Boolean): (String, String) => Boolean = 
{
+    if (caseSensitive) { (str1: String, str2: String) => str1.equals(str2) }
+    else { (str1: String, str2: String) => str1.equalsIgnoreCase(str2) }
+  }
+
+  // If user specify schema by .schema(newSchemaDifferentWithFile)
+  def checkSchema(
+      requiredField: DataType,
+      parquetFileFieldType: ArrowType,
+      parquetFileFields: mutable.Buffer[Field]): Unit = {
+    val requiredFieldType =
+      SparkArrowUtil.toArrowType(requiredField, 
SparkSchemaUtil.getLocalTimezoneID)
+    if (!requiredFieldType.equals(parquetFileFieldType)) {
+      val arrowFileSchema = parquetFileFields
+        .map(f => f.toString)
+        .reduceLeft((f1, f2) => f1 + "\n" + f2)
+      throw new SchemaMismatchException(
+        s"Not support specified schema is different with file 
schema\n$arrowFileSchema")
+    }
+  }
+
+  def getRequestedField(
+      requiredSchema: StructType,
+      parquetFileFields: mutable.Buffer[Field],
+      caseSensitive: Boolean): Schema = {
+    val compareFunc = compareStringFunc(caseSensitive)
+    requiredSchema.foreach {
+      readField =>
+        // TODO: check schema inside of complex type
+        val matchedFields =
+          parquetFileFields.filter(field => compareFunc(field.getName, 
readField.name))
+        if (!caseSensitive && matchedFields.size > 1) {
+          // Need to fail if there is ambiguity, i.e. more than one field is 
matched
+          val fieldsString = matchedFields.map(_.getName).mkString("[", ", ", 
"]")
+          throw new RuntimeException(
+            s"""
+               |Found duplicate field(s) "${readField.name}": $fieldsString
+
+               |in case-insensitive mode""".stripMargin.replaceAll("\n", " "))
+        }
+        if (matchedFields.nonEmpty) {
+          checkSchema(
+            readField.dataType,
+            matchedFields.head.getFieldType.getType,
+            parquetFileFields)
+        }
+    }
+
+    val requestColNames = requiredSchema.map(_.name)
+    new Schema(parquetFileFields.filter {
+      field => requestColNames.exists(col => compareFunc(col, field.getName))
+    }.asJava)
+  }
+
+  def loadMissingColumns(
+      rowCount: Int,
+      missingSchema: StructType): Array[ArrowWritableColumnVector] = {
+
+    val vectors =
+      ArrowWritableColumnVector.allocateColumns(rowCount, missingSchema)
+    vectors.foreach {
+      vector =>
+        vector.putNulls(0, rowCount)
+        vector.setValueCount(rowCount)
+    }
+
+    SparkMemoryUtil.addLeakSafeTaskCompletionListener(
+      (_: TaskContext) => {
+        vectors.foreach(_.close())
+      })

Review Comment:
   Is there a way to avoid closing these vectors through a task completion 
listener?



##########
gluten-core/src/main/scala/org/apache/gluten/utils/PlanUtil.scala:
##########
@@ -70,6 +70,9 @@ object PlanUtil {
       case _: InputAdapter => false
       case _: WholeStageCodegenExec => false
       case r: ReusedExchangeExec => isVanillaColumnarOp(r.child)
+      case f: FileSourceScanExec =>
+        
!f.relation.fileFormat.getClass.getSimpleName.equals("ArrowCSVFileFormat") &&
+        f.supportsColumnar

Review Comment:
   Could you explain this change? Thanks.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to