This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3bc374d945ff [SPARK-50333][SQL] Codegen Support for `CsvToStructs` (by Invoke & RuntimeReplaceable)
3bc374d945ff is described below

commit 3bc374d945ff91cda78e64c1d63fe9a95f735ebf
Author: panbingkun <[email protected]>
AuthorDate: Thu Nov 21 09:09:24 2024 +0100

    [SPARK-50333][SQL] Codegen Support for `CsvToStructs` (by Invoke & RuntimeReplaceable)
    
    ### What changes were proposed in this pull request?
    This PR adds `Codegen` support for `CsvToStructs` (`from_csv`).
    
    ### Why are the changes needed?
    - Improve codegen coverage.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Pass GA & existing UTs (e.g. CsvFunctionsSuite#`*from_csv*`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48873 from panbingkun/from_csv_codegen.
    
    Lead-authored-by: panbingkun <[email protected]>
    Co-authored-by: panbingkun <[email protected]>
    Signed-off-by: Max Gekk <[email protected]>
---
 .../expressions/csv/CsvExpressionEvalUtils.scala   | 70 ++++++++++++++++-
 .../sql/catalyst/expressions/csvExpressions.scala  | 87 ++++++----------------
 .../explain-results/function_from_csv.explain      |  2 +-
 3 files changed, 93 insertions(+), 66 deletions(-)
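
For readers unfamiliar with the pattern named in the title: `RuntimeReplaceable` lets an expression declare a `replacement` that the optimizer substitutes during planning, and `Invoke` compiles to a direct method call on a bound object, so the evaluation logic lives in a plain Scala class and codegen comes for free, with no `CodegenFallback`. A minimal sketch of that shape, with illustrative names (`UpperEvaluator` and `UpperViaInvoke` are not part of Spark):

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, RuntimeReplaceable, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.{DataType, ObjectType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Plain Scala class holding the evaluation logic; generated code calls
// `evaluate` through an Invoke node instead of an interpreted fallback.
case class UpperEvaluator() {
  def evaluate(input: UTF8String): UTF8String = input.toUpperCase
}

// Illustrative expression wired the same way CsvToStructs is in this commit.
case class UpperViaInvoke(child: Expression)
  extends UnaryExpression with RuntimeReplaceable {

  override def dataType: DataType = StringType

  // The optimizer rewrites this expression into the Invoke below.
  override def replacement: Expression = Invoke(
    Literal.create(UpperEvaluator(), ObjectType(classOf[UpperEvaluator])),
    "evaluate",
    dataType,
    Seq(child),
    Seq(child.dataType))

  override protected def withNewChildInternal(newChild: Expression): UpperViaInvoke =
    copy(child = newChild)
}
```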

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala
index abd0703fa7d7..a91e4ab13001 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala
@@ -18,10 +18,78 @@ package org.apache.spark.sql.catalyst.expressions.csv
 
 import com.univocity.parsers.csv.CsvParser
 
-import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions}
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions, UnivocityParser}
+import org.apache.spark.sql.catalyst.expressions.ExprUtils
+import org.apache.spark.sql.catalyst.util.{FailFastMode, FailureSafeParser, PermissiveMode}
+import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.types.{DataType, NullType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
+/**
+ * The expression `CsvToStructs` calls this evaluator via `Invoke`, which enables codegen support.
+ */
+case class CsvToStructsEvaluator(
+    options: Map[String, String],
+    nullableSchema: StructType,
+    nameOfCorruptRecord: String,
+    timeZoneId: Option[String],
+    requiredSchema: Option[StructType]) {
+
+  // This converts parsed rows to the desired output by the given schema.
+  @transient
+  private lazy val converter = (rows: Iterator[InternalRow]) => {
+    if (!rows.hasNext) {
+      throw SparkException.internalError("Expected one row from CSV parser.")
+    }
+    val result = rows.next()
+    // CSV's parser produces one record only.
+    assert(!rows.hasNext)
+    result
+  }
+
+  @transient
+  private lazy val parser = {
+    // 'lineSep' is a plan-wise option so we set a noncharacter, according to
+    // the unicode specification, which should not appear in Java's strings.
+    // See also SPARK-38955 and https://www.unicode.org/charts/PDF/UFFF0.pdf.
+    // scalastyle:off nonascii
+    val exprOptions = options ++ Map("lineSep" -> '\uFFFF'.toString)
+    // scalastyle:on nonascii
+    val parsedOptions = new CSVOptions(
+      exprOptions,
+      columnPruning = true,
+      defaultTimeZoneId = timeZoneId.get,
+      defaultColumnNameOfCorruptRecord = nameOfCorruptRecord)
+    val mode = parsedOptions.parseMode
+    if (mode != PermissiveMode && mode != FailFastMode) {
+      throw QueryCompilationErrors.parseModeUnsupportedError("from_csv", mode)
+    }
+    ExprUtils.verifyColumnNameOfCorruptRecord(
+      nullableSchema,
+      parsedOptions.columnNameOfCorruptRecord)
+
+    val actualSchema =
+      StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
+    val actualRequiredSchema =
+      StructType(requiredSchema.map(_.asNullable).getOrElse(nullableSchema)
+        .filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
+    val rawParser = new UnivocityParser(actualSchema,
+      actualRequiredSchema,
+      parsedOptions)
+    new FailureSafeParser[String](
+      input => rawParser.parse(input),
+      mode,
+      nullableSchema,
+      parsedOptions.columnNameOfCorruptRecord)
+  }
+
+  final def evaluate(csv: UTF8String): InternalRow = {
+    converter(parser.parse(csv.toString))
+  }
+}
+
 case class SchemaOfCsvEvaluator(options: Map[String, String]) {
 
   @transient
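
Since `CsvToStructsEvaluator` is an ordinary case class, it can be exercised on its own, outside any expression tree. A hedged sketch (the schema, options, and timezone below are made-up example values):

```scala
import org.apache.spark.sql.catalyst.expressions.csv.CsvToStructsEvaluator
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

val schema = StructType(Seq(
  StructField("a", IntegerType),
  StructField("b", StringType)))

// timeZoneId must be defined: the evaluator calls timeZoneId.get when it
// builds CSVOptions, mirroring what CsvToStructs passes in after timezone
// resolution.
val evaluator = CsvToStructsEvaluator(
  options = Map.empty,
  nullableSchema = schema,
  nameOfCorruptRecord = "_corrupt_record",
  timeZoneId = Some("UTC"),
  requiredSchema = None)

// Parses a single CSV record into an InternalRow with fields (1, "x").
val row = evaluator.evaluate(UTF8String.fromString("1,x"))
```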
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index e9cdc184e55a..02e5488835c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -19,17 +19,16 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.io.CharArrayWriter
 
-import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.csv._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
-import org.apache.spark.sql.catalyst.expressions.csv.SchemaOfCsvEvaluator
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.csv.{CsvToStructsEvaluator, SchemaOfCsvEvaluator}
 import org.apache.spark.sql.catalyst.expressions.objects.Invoke
-import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
 import org.apache.spark.sql.catalyst.util.TypeUtils._
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
+import org.apache.spark.sql.errors.QueryErrorsBase
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.types.StringTypeWithCollation
 import org.apache.spark.sql.types._
@@ -58,15 +57,17 @@ case class CsvToStructs(
     timeZoneId: Option[String] = None,
     requiredSchema: Option[StructType] = None)
   extends UnaryExpression
-    with TimeZoneAwareExpression
-    with CodegenFallback
-    with ExpectsInputTypes {
-  override def nullIntolerant: Boolean = true
+  with RuntimeReplaceable
+  with ExpectsInputTypes
+  with TimeZoneAwareExpression {
+
   override def nullable: Boolean = child.nullable
 
+  override def nodePatternsInternal(): Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE)
+
   // The CSV input data might be missing certain fields. We force the nullability
   // of the user-provided schema to avoid data corruptions.
-  val nullableSchema: StructType = schema.asNullable
+  private val nullableSchema: StructType = schema.asNullable
 
   // Used in `FunctionRegistry`
   def this(child: Expression, schema: Expression, options: Map[String, String]) =
@@ -85,55 +86,7 @@ case class CsvToStructs(
       child = child,
       timeZoneId = None)
 
-  // This converts parsed rows to the desired output by the given schema.
-  @transient
-  lazy val converter = (rows: Iterator[InternalRow]) => {
-    if (rows.hasNext) {
-      val result = rows.next()
-      // CSV's parser produces one record only.
-      assert(!rows.hasNext)
-      result
-    } else {
-      throw SparkException.internalError("Expected one row from CSV parser.")
-    }
-  }
-
-  val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
-
-  @transient lazy val parser = {
-    // 'lineSep' is a plan-wise option so we set a noncharacter, according to
-    // the unicode specification, which should not appear in Java's strings.
-    // See also SPARK-38955 and https://www.unicode.org/charts/PDF/UFFF0.pdf.
-    // scalastyle:off nonascii
-    val exprOptions = options ++ Map("lineSep" -> '\uFFFF'.toString)
-    // scalastyle:on nonascii
-    val parsedOptions = new CSVOptions(
-      exprOptions,
-      columnPruning = true,
-      defaultTimeZoneId = timeZoneId.get,
-      defaultColumnNameOfCorruptRecord = nameOfCorruptRecord)
-    val mode = parsedOptions.parseMode
-    if (mode != PermissiveMode && mode != FailFastMode) {
-      throw QueryCompilationErrors.parseModeUnsupportedError("from_csv", mode)
-    }
-    ExprUtils.verifyColumnNameOfCorruptRecord(
-      nullableSchema,
-      parsedOptions.columnNameOfCorruptRecord)
-
-    val actualSchema =
-      StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
-    val actualRequiredSchema =
-      StructType(requiredSchema.map(_.asNullable).getOrElse(nullableSchema)
-        .filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
-    val rawParser = new UnivocityParser(actualSchema,
-      actualRequiredSchema,
-      parsedOptions)
-    new FailureSafeParser[String](
-      input => rawParser.parse(input),
-      mode,
-      nullableSchema,
-      parsedOptions.columnNameOfCorruptRecord)
-  }
+  private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
 
   override def dataType: DataType = requiredSchema.getOrElse(schema).asNullable
 
@@ -141,15 +94,21 @@ case class CsvToStructs(
     copy(timeZoneId = Option(timeZoneId))
   }
 
-  override def nullSafeEval(input: Any): Any = {
-    val csv = input.asInstanceOf[UTF8String].toString
-    converter(parser.parse(csv))
-  }
-
   override def inputTypes: Seq[AbstractDataType] = StringTypeWithCollation :: Nil
 
   override def prettyName: String = "from_csv"
 
+  @transient
+  private lazy val evaluator: CsvToStructsEvaluator = CsvToStructsEvaluator(
+    options, nullableSchema, nameOfCorruptRecord, timeZoneId, requiredSchema)
+
+  override def replacement: Expression = Invoke(
+    Literal.create(evaluator, ObjectType(classOf[CsvToStructsEvaluator])),
+    "evaluate",
+    dataType,
+    Seq(child),
+    Seq(child.dataType))
+
   override protected def withNewChildInternal(newChild: Expression): CsvToStructs =
     copy(child = newChild)
 }
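
On the expression side, the key change is that `CsvToStructs` now extends `RuntimeReplaceable`: the optimizer matches it via the `RUNTIME_REPLACEABLE` tree pattern and swaps it for its `replacement`, the `Invoke` over a `CsvToStructsEvaluator` literal, so the expression no longer needs `CodegenFallback` or its own `nullSafeEval`. A hedged sketch of constructing the expression and inspecting the rewrite:

```scala
import org.apache.spark.sql.catalyst.expressions.{CsvToStructs, Literal}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(Seq(StructField("a", IntegerType)))

// timeZoneId is normally filled in by the analyzer; it is set explicitly
// here only to keep the sketch self-contained.
val expr = CsvToStructs(
  schema = schema,
  options = Map.empty,
  child = Literal("1"),
  timeZoneId = Some("UTC"))

// What the optimizer substitutes: an Invoke over a CsvToStructsEvaluator
// literal, which is exactly what the updated explain output below shows.
val rewritten = expr.replacement
```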
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_csv.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_csv.explain
index 89e03c818823..ef87c18948b2 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_csv.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_csv.explain
@@ -1,2 +1,2 @@
-Project [from_csv(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), (mode,FAILFAST), g#0, Some(America/Los_Angeles), None) AS from_csv(g)#0]
+Project [invoke(CsvToStructsEvaluator(Map(mode -> FAILFAST),StructType(StructField(id,LongType,true),StructField(a,IntegerType,true),StructField(b,DoubleType,true)),_corrupt_record,Some(America/Los_Angeles),None).evaluate(g#0)) AS from_csv(g)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
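
Nothing changes at the API surface: `from_csv` returns the same results, and only the optimized plan takes the `invoke(CsvToStructsEvaluator(...))` form shown above. A quick usage sketch (example data and session setup are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.from_csv
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val spark = SparkSession.builder().master("local[1]").appName("from_csv-demo").getOrCreate()
import spark.implicits._

val schema = StructType(Seq(
  StructField("a", IntegerType),
  StructField("b", StringType)))

// Same user-visible results before and after this commit; only the plan's
// representation of from_csv changes.
Seq("1,x", "2,y").toDF("csv")
  .select(from_csv($"csv", schema, Map.empty[String, String]).as("parsed"))
  .show(truncate = false)
```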

