This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 98ae33b7acb [SPARK-43142] Fix DSL expressions on attributes with 
special characters
98ae33b7acb is described below

commit 98ae33b7acbd932714301c83d71d42ef318dda9b
Author: Willi Raschkowski <wraschkow...@palantir.com>
AuthorDate: Tue Apr 25 18:23:36 2023 +0800

    [SPARK-43142] Fix DSL expressions on attributes with special characters
    
    Re-attempting #40794. #40794 tried to more safely create 
`AttributeReference` objects from multi-part attributes in `ImplicitAttribute`. 
But that broke things and we had to revert. This PR limits the fix to the 
`UnresolvedAttribute` object returned by `DslAttr.attr`, which is enough to fix 
the issue here.
    
    ### What changes were proposed in this pull request?
    This PR fixes DSL expressions on attributes with special characters by 
making `DslAttr.attr` and `DslAttr.expr` return the implicitly wrapped 
attribute instead of creating a new one.
    
    ### Why are the changes needed?
    SPARK-43142: DSL expressions on attributes with special characters don't 
work even if the attribute names are quoted:
    
    ```scala
    scala> "`slashed/col`".attr
    res0: org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute = 
'slashed/col
    
    scala> "`slashed/col`".attr.asc
    org.apache.spark.sql.catalyst.parser.ParseException:
    mismatched input '/' expecting {<EOF>, '.', '-'}(line 1, pos 7)
    
    == SQL ==
    slashed/col
    -------^^^
    ```
    
    DSL expressions rely on a call to `expr` to get child of the new expression 
[(e.g.)](https://github.com/apache/spark/blob/87a5442f7ed96b11051d8a9333476d080054e5a0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala#L149).
    
    `expr` here is a call on the implicit class `DslAttr` that wraps the 
`UnresolvedAttribute` returned by `"...".attr`.
    
    `DslAttr` and its super class implement `DslAttr.expr` such that a new 
`UnresolvedAttribute` is created from `UnresolvedAttribute.name` of the wrapped 
attribute 
[(here)](https://github.com/apache/spark/blob/87a5442f7ed96b11051d8a9333476d080054e5a0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala#L273-L280).
    
    But `UnresolvedAttribute.name` drops the quotes and thus the newly created 
`UnresolvedAttribute` parses an identifier that should be quoted but isn't:
    ```scala
    scala> "`col/slash`".attr.name
    res5: String = col/slash
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    DSL expressions on attributes with special characters no longer fail.
    
    ### How was this patch tested?
    I couldn't find a suite testing the implicit classes in the DSL package, 
but the DSL package seems used widely enough that I'm confident this doesn't 
break existing behavior.
    
    Locally, I was able to reproduce with this test; it was failing before and 
passes now:
    ```scala
    test("chained DSL expressions on attributes with special characters") {
      $"`slashed/col`".asc
    }
    ```
    
    Closes #40902 from rshkv/wr/spark-43142-v2.
    
    Authored-by: Willi Raschkowski <wraschkow...@palantir.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/catalyst/dsl/package.scala    | 73 ++++++++++------------
 .../expressions/ExpressionSQLBuilderSuite.scala    |  2 +-
 .../datasources/DataSourceStrategySuite.scala      | 12 ++--
 .../datasources/v2/DataSourceV2StrategySuite.scala | 12 ++--
 4 files changed, 46 insertions(+), 53 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index ac439203cb7..27d05f3bac7 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -271,8 +271,8 @@ package object dsl {
       override def expr: Expression = Literal(s)
       def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s)
     }
-    implicit class DslAttr(attr: UnresolvedAttribute) extends 
ImplicitAttribute {
-      def s: String = attr.name
+    implicit class DslAttr(override val attr: UnresolvedAttribute) extends 
ImplicitAttribute {
+      def s: String = attr.sql
     }
 
     abstract class ImplicitAttribute extends ImplicitOperators {
@@ -280,90 +280,83 @@ package object dsl {
       def expr: UnresolvedAttribute = attr
       def attr: UnresolvedAttribute = analysis.UnresolvedAttribute(s)
 
+      private def attrRef(dataType: DataType): AttributeReference =
+        AttributeReference(attr.nameParts.last, dataType)(qualifier = 
attr.nameParts.init)
+
       /** Creates a new AttributeReference of type boolean */
-      def boolean: AttributeReference = AttributeReference(s, BooleanType, 
nullable = true)()
+      def boolean: AttributeReference = attrRef(BooleanType)
 
       /** Creates a new AttributeReference of type byte */
-      def byte: AttributeReference = AttributeReference(s, ByteType, nullable 
= true)()
+      def byte: AttributeReference = attrRef(ByteType)
 
       /** Creates a new AttributeReference of type short */
-      def short: AttributeReference = AttributeReference(s, ShortType, 
nullable = true)()
+      def short: AttributeReference = attrRef(ShortType)
 
       /** Creates a new AttributeReference of type int */
-      def int: AttributeReference = AttributeReference(s, IntegerType, 
nullable = true)()
+      def int: AttributeReference = attrRef(IntegerType)
 
       /** Creates a new AttributeReference of type long */
-      def long: AttributeReference = AttributeReference(s, LongType, nullable 
= true)()
+      def long: AttributeReference = attrRef(LongType)
 
       /** Creates a new AttributeReference of type float */
-      def float: AttributeReference = AttributeReference(s, FloatType, 
nullable = true)()
+      def float: AttributeReference = attrRef(FloatType)
 
       /** Creates a new AttributeReference of type double */
-      def double: AttributeReference = AttributeReference(s, DoubleType, 
nullable = true)()
+      def double: AttributeReference = attrRef(DoubleType)
 
       /** Creates a new AttributeReference of type string */
-      def string: AttributeReference = AttributeReference(s, StringType, 
nullable = true)()
+      def string: AttributeReference = attrRef(StringType)
 
       /** Creates a new AttributeReference of type date */
-      def date: AttributeReference = AttributeReference(s, DateType, nullable 
= true)()
+      def date: AttributeReference = attrRef(DateType)
 
       /** Creates a new AttributeReference of type decimal */
-      def decimal: AttributeReference =
-        AttributeReference(s, DecimalType.SYSTEM_DEFAULT, nullable = true)()
+      def decimal: AttributeReference = attrRef(DecimalType.SYSTEM_DEFAULT)
 
       /** Creates a new AttributeReference of type decimal */
       def decimal(precision: Int, scale: Int): AttributeReference =
-        AttributeReference(s, DecimalType(precision, scale), nullable = true)()
+        attrRef(DecimalType(precision, scale))
 
       /** Creates a new AttributeReference of type timestamp */
-      def timestamp: AttributeReference = AttributeReference(s, TimestampType, 
nullable = true)()
+      def timestamp: AttributeReference = attrRef(TimestampType)
 
       /** Creates a new AttributeReference of type timestamp without time zone 
*/
-      def timestampNTZ: AttributeReference =
-        AttributeReference(s, TimestampNTZType, nullable = true)()
+      def timestampNTZ: AttributeReference = attrRef(TimestampNTZType)
 
       /** Creates a new AttributeReference of the day-time interval type */
-      def dayTimeInterval(startField: Byte, endField: Byte): 
AttributeReference = {
-        AttributeReference(s, DayTimeIntervalType(startField, endField), 
nullable = true)()
-      }
-      def dayTimeInterval(): AttributeReference = {
-        AttributeReference(s, DayTimeIntervalType(), nullable = true)()
-      }
+      def dayTimeInterval(startField: Byte, endField: Byte): 
AttributeReference =
+        attrRef(DayTimeIntervalType(startField, endField))
+
+      def dayTimeInterval(): AttributeReference = 
attrRef(DayTimeIntervalType())
 
       /** Creates a new AttributeReference of the year-month interval type */
-      def yearMonthInterval(startField: Byte, endField: Byte): 
AttributeReference = {
-        AttributeReference(s, YearMonthIntervalType(startField, endField), 
nullable = true)()
-      }
-      def yearMonthInterval(): AttributeReference = {
-        AttributeReference(s, YearMonthIntervalType(), nullable = true)()
-      }
+      def yearMonthInterval(startField: Byte, endField: Byte): 
AttributeReference =
+        attrRef(YearMonthIntervalType(startField, endField))
+
+      def yearMonthInterval(): AttributeReference = 
attrRef(YearMonthIntervalType())
 
       /** Creates a new AttributeReference of type binary */
-      def binary: AttributeReference = AttributeReference(s, BinaryType, 
nullable = true)()
+      def binary: AttributeReference = attrRef(BinaryType)
 
       /** Creates a new AttributeReference of type array */
-      def array(dataType: DataType): AttributeReference =
-        AttributeReference(s, ArrayType(dataType), nullable = true)()
+      def array(dataType: DataType): AttributeReference = 
attrRef(ArrayType(dataType))
 
-      def array(arrayType: ArrayType): AttributeReference =
-        AttributeReference(s, arrayType)()
+      def array(arrayType: ArrayType): AttributeReference = attrRef(arrayType)
 
       /** Creates a new AttributeReference of type map */
       def map(keyType: DataType, valueType: DataType): AttributeReference =
         map(MapType(keyType, valueType))
 
-      def map(mapType: MapType): AttributeReference =
-        AttributeReference(s, mapType, nullable = true)()
+      def map(mapType: MapType): AttributeReference = attrRef(mapType)
 
       /** Creates a new AttributeReference of type struct */
-      def struct(structType: StructType): AttributeReference =
-        AttributeReference(s, structType, nullable = true)()
+      def struct(structType: StructType): AttributeReference = 
attrRef(structType)
+
       def struct(attrs: AttributeReference*): AttributeReference =
         struct(StructType.fromAttributes(attrs))
 
       /** Creates a new AttributeReference of object type */
-      def obj(cls: Class[_]): AttributeReference =
-        AttributeReference(s, ObjectType(cls), nullable = true)()
+      def obj(cls: Class[_]): AttributeReference = attrRef(ObjectType(cls))
 
       /** Create a function. */
       def function(exprs: Expression*): UnresolvedFunction =
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala
index d450aecb732..e88b0e32e90 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala
@@ -95,7 +95,7 @@ class ExpressionSQLBuilderSuite extends SparkFunSuite {
 
   test("attributes") {
     checkSQL($"a".int, "a")
-    checkSQL(Symbol("foo bar").int, "`foo bar`")
+    checkSQL(Symbol("`foo bar`").int, "`foo bar`")
     // Keyword
     checkSQL($"int".int, "int")
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
index cf8aea45583..a35fb5f6271 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala
@@ -27,18 +27,18 @@ import org.apache.spark.sql.types.{IntegerType, StringType, 
StructField, StructT
 class DataSourceStrategySuite extends PlanTest with SharedSparkSession {
   val attrInts = Seq(
     $"cint".int,
-    $"c.int".int,
+    $"`c.int`".int,
     GetStructField($"a".struct(StructType(
       StructField("cstr", StringType, nullable = true) ::
         StructField("cint", IntegerType, nullable = true) :: Nil)), 1, None),
     GetStructField($"a".struct(StructType(
       StructField("c.int", IntegerType, nullable = true) ::
         StructField("cstr", StringType, nullable = true) :: Nil)), 0, None),
-    GetStructField($"a.b".struct(StructType(
+    GetStructField($"`a.b`".struct(StructType(
       StructField("cstr1", StringType, nullable = true) ::
         StructField("cstr2", StringType, nullable = true) ::
         StructField("cint", IntegerType, nullable = true) :: Nil)), 2, None),
-    GetStructField($"a.b".struct(StructType(
+    GetStructField($"`a.b`".struct(StructType(
       StructField("c.int", IntegerType, nullable = true) :: Nil)), 0, None),
     GetStructField(GetStructField($"a".struct(StructType(
       StructField("cstr1", StringType, nullable = true) ::
@@ -56,18 +56,18 @@ class DataSourceStrategySuite extends PlanTest with 
SharedSparkSession {
 
   val attrStrs = Seq(
     $"cstr".string,
-    $"c.str".string,
+    $"`c.str`".string,
     GetStructField($"a".struct(StructType(
       StructField("cint", IntegerType, nullable = true) ::
         StructField("cstr", StringType, nullable = true) :: Nil)), 1, None),
     GetStructField($"a".struct(StructType(
       StructField("c.str", StringType, nullable = true) ::
         StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None),
-    GetStructField($"a.b".struct(StructType(
+    GetStructField($"`a.b`".struct(StructType(
       StructField("cint1", IntegerType, nullable = true) ::
         StructField("cint2", IntegerType, nullable = true) ::
         StructField("cstr", StringType, nullable = true) :: Nil)), 2, None),
-    GetStructField($"a.b".struct(StructType(
+    GetStructField($"`a.b`".struct(StructType(
       StructField("c.str", StringType, nullable = true) :: Nil)), 0, None),
     GetStructField(GetStructField($"a".struct(StructType(
       StructField("cint1", IntegerType, nullable = true) ::
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
index 8d6ffa30a72..3c4f5814375 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
@@ -29,18 +29,18 @@ import org.apache.spark.unsafe.types.UTF8String
 class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
   val attrInts = Seq(
     $"cint".int,
-    $"c.int".int,
+    $"`c.int`".int,
     GetStructField($"a".struct(StructType(
       StructField("cstr", StringType, nullable = true) ::
         StructField("cint", IntegerType, nullable = true) :: Nil)), 1, None),
     GetStructField($"a".struct(StructType(
       StructField("c.int", IntegerType, nullable = true) ::
         StructField("cstr", StringType, nullable = true) :: Nil)), 0, None),
-    GetStructField($"a.b".struct(StructType(
+    GetStructField($"`a.b`".struct(StructType(
       StructField("cstr1", StringType, nullable = true) ::
         StructField("cstr2", StringType, nullable = true) ::
         StructField("cint", IntegerType, nullable = true) :: Nil)), 2, None),
-    GetStructField($"a.b".struct(StructType(
+    GetStructField($"`a.b`".struct(StructType(
       StructField("c.int", IntegerType, nullable = true) :: Nil)), 0, None),
     GetStructField(GetStructField($"a".struct(StructType(
       StructField("cstr1", StringType, nullable = true) ::
@@ -58,18 +58,18 @@ class DataSourceV2StrategySuite extends PlanTest with 
SharedSparkSession {
 
   val attrStrs = Seq(
     $"cstr".string,
-    $"c.str".string,
+    $"`c.str`".string,
     GetStructField($"a".struct(StructType(
       StructField("cint", IntegerType, nullable = true) ::
         StructField("cstr", StringType, nullable = true) :: Nil)), 1, None),
     GetStructField($"a".struct(StructType(
       StructField("c.str", StringType, nullable = true) ::
         StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None),
-    GetStructField($"a.b".struct(StructType(
+    GetStructField($"`a.b`".struct(StructType(
       StructField("cint1", IntegerType, nullable = true) ::
         StructField("cint2", IntegerType, nullable = true) ::
         StructField("cstr", StringType, nullable = true) :: Nil)), 2, None),
-    GetStructField($"a.b".struct(StructType(
+    GetStructField($"`a.b`".struct(StructType(
       StructField("c.str", StringType, nullable = true) :: Nil)), 0, None),
     GetStructField(GetStructField($"a".struct(StructType(
       StructField("cint1", IntegerType, nullable = true) ::


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to