Repository: spark
Updated Branches:
  refs/heads/branch-2.2 8acd02f42 -> f73637798


[SPARK-22442][SQL][BRANCH-2.2] ScalaReflection should produce correct field names for special characters

## What changes were proposed in this pull request?

For a case class with field names containing special characters, e.g.:
```scala
case class MyType(`field.1`: String, `field 2`: String)
```

Although we can still manipulate the DataFrame/Dataset, the field names in the schema are encoded:
```scala
scala> val df = Seq(MyType("a", "b"), MyType("c", "d")).toDF
df: org.apache.spark.sql.DataFrame = [field$u002E1: string, field$u00202: string]
scala> df.as[MyType].collect
res7: Array[MyType] = Array(MyType(a,b), MyType(c,d))
```

This causes a resolution problem when we try to convert data whose field names are not encoded:
```scala
spark.read.json(path).as[MyType]
...
[info]   org.apache.spark.sql.AnalysisException: cannot resolve '`field$u002E1`' given input columns: [field 2, field.1];
[info]   at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
...
```

We should use the decoded field names in the Dataset schema.
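
For context (not part of the patch): Scala stores such names in an encoded form, and `scala.reflect` exposes both spellings. A minimal sketch of the round trip, using only the standard reflection API:

```scala
import scala.reflect.runtime.universe.TermName

// `.` and ` ` are not valid JVM identifier characters, so scalac escapes them.
TermName("field.1").encodedName.toString        // "field$u002E1"
TermName("field 2").encodedName.toString        // "field$u00202"
// decodedName recovers the source-level spelling, which the schema should expose.
TermName("field$u002E1").decodedName.toString   // "field.1"
```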

## How was this patch tested?

Added tests.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #19734 from viirya/SPARK-22442-2.2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f7363779
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f7363779
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f7363779

Branch: refs/heads/branch-2.2
Commit: f736377980fd6cd46346512a68482a88aa6a1711
Parents: 8acd02f
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Sun Nov 12 21:19:15 2017 -0800
Committer: Felix Cheung <felixche...@apache.org>
Committed: Sun Nov 12 21:19:15 2017 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala     |  9 +++++----
 .../catalyst/expressions/objects/objects.scala   | 11 +++++++----
 .../sql/catalyst/ScalaReflectionSuite.scala      | 19 ++++++++++++++++++-
 .../org/apache/spark/sql/DatasetSuite.scala      | 12 ++++++++++++
 4 files changed, 42 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f7363779/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index ad21842..7f72751 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -151,7 +151,7 @@ object ScalaReflection extends ScalaReflection {
     def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
       val newPath = path
         .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
-        .getOrElse(UnresolvedAttribute(part))
+        .getOrElse(UnresolvedAttribute.quoted(part))
       upCastToExpectedType(newPath, dataType, walkedTypePath)
     }
 
@@ -671,7 +671,7 @@ object ScalaReflection extends ScalaReflection {
     val m = runtimeMirror(cls.getClassLoader)
     val classSymbol = m.staticClass(cls.getName)
     val t = classSymbol.selfType
-    constructParams(t).map(_.name.toString)
+    constructParams(t).map(_.name.decodedName.toString)
   }
 
   /**
@@ -861,11 +861,12 @@ trait ScalaReflection {
     // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
     if (actualTypeArgs.nonEmpty) {
       params.map { p =>
-        p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+        p.name.decodedName.toString ->
+          p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
       }
     } else {
       params.map { p =>
-        p.name.toString -> p.typeSignature
+        p.name.decodedName.toString -> p.typeSignature
       }
     }
   }
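
A note on the hunks above (not part of the patch): `Symbol#name` on a reflected constructor parameter carries the encoded spelling, which is what previously leaked into the schema; `decodedName` restores the source-level name. A standalone sketch along the lines of `getConstructorParameterNames`, assuming an illustrative top-level `Demo` class:

```scala
import scala.reflect.runtime.{universe => ru}

case class Demo(`field.1`: String, `field 2`: String)

val mirror = ru.runtimeMirror(classOf[Demo].getClassLoader)
val tpe = mirror.staticClass(classOf[Demo].getName).selfType
val params = tpe.member(ru.termNames.CONSTRUCTOR).asMethod.paramLists.flatten

params.map(_.name.toString)              // List(field$u002E1, field$u00202) -- old behavior
params.map(_.name.decodedName.toString)  // List(field.1, field 2)           -- the fix
```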

http://git-wip-us.apache.org/repos/asf/spark/blob/f7363779/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 43cef6c..0b45dfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
@@ -189,11 +190,13 @@ case class Invoke(
   override def eval(input: InternalRow): Any =
     throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
 
+  private lazy val encodedFunctionName = TermName(functionName).encodedName.toString
+
   @transient lazy val method = targetObject.dataType match {
     case ObjectType(cls) =>
-      val m = cls.getMethods.find(_.getName == functionName)
+      val m = cls.getMethods.find(_.getName == encodedFunctionName)
       if (m.isEmpty) {
-        sys.error(s"Couldn't find $functionName on $cls")
+        sys.error(s"Couldn't find $encodedFunctionName on $cls")
       } else {
         m
       }
@@ -222,7 +225,7 @@ case class Invoke(
     }
 
     val evaluate = if (returnPrimitive) {
-      getFuncResult(ev.value, s"${obj.value}.$functionName($argString)")
+      getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)")
     } else {
       val funcResult = ctx.freshName("funcResult")
       // If the function can return null, we do an extra check to make sure our null bit is still
@@ -240,7 +243,7 @@ case class Invoke(
       }
       s"""
         Object $funcResult = null;
-        ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")}
+        ${getFuncResult(funcResult, s"${obj.value}.$encodedFunctionName($argString)")}
         $assignResult
       """
     }
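
The counterpart on the `Invoke` side (not part of the patch): once the schema carries decoded names, generated code must still call the JVM method, which scalac emits under the encoded name; hence the `encodedName` lookup above. A hedged sketch with an illustrative `Demo` class:

```scala
import scala.reflect.runtime.universe.TermName

case class Demo(`field.1`: String)

// Java reflection only sees the encoded accessor name, so Invoke must
// encode the function name before searching getMethods.
val jvmNames = classOf[Demo].getMethods.map(_.getName).toSet
assert(jvmNames.contains(TermName("field.1").encodedName.toString)) // field$u002E1
assert(!jvmNames.contains("field.1"))
```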

http://git-wip-us.apache.org/repos/asf/spark/blob/f7363779/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 35683ef..c0e4e37 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -23,7 +23,8 @@ import java.sql.{Date, Timestamp}
 import scala.reflect.runtime.universe.typeOf
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -83,6 +84,8 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) {
   def this(b: String, a: Int) = this(a, b, c = 1.0)
 }
 
+case class SpecialCharAsFieldData(`field.1`: String, `field 2`: String)
+
 object TestingUDT {
   @SQLUserDefinedType(udt = classOf[NestedStructUDT])
   class NestedStruct(val a: Integer, val b: Long, val c: Double)
@@ -354,4 +357,18 @@ class ScalaReflectionSuite extends SparkFunSuite {
     assert(deserializerFor[Int].isInstanceOf[AssertNotNull])
     assert(!deserializerFor[String].isInstanceOf[AssertNotNull])
   }
+
+  test("SPARK-22442: Generate correct field names for special characters") {
+    val serializer = serializerFor[SpecialCharAsFieldData](BoundReference(
+      0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false))
+    val deserializer = deserializerFor[SpecialCharAsFieldData]
+    assert(serializer.dataType(0).name == "field.1")
+    assert(serializer.dataType(1).name == "field 2")
+
+    val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect {
+      case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts
+    }}
+    assert(argumentsFields(0) == Seq("field.1"))
+    assert(argumentsFields(1) == Seq("field 2"))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f7363779/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 88a4167..683fe4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1228,6 +1228,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       assert(e.getCause.isInstanceOf[NullPointerException])
     }
   }
+
+  test("SPARK-22442: Generate correct field names for special characters") {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      val data = """{"field.1": 1, "field 2": 2}"""
+      Seq(data).toDF().repartition(1).write.text(path)
+      val ds = spark.read.json(path).as[SpecialCharClass]
+      checkDataset(ds, SpecialCharClass("1", "2"))
+    }
+  }
 }
 
 case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])
@@ -1313,3 +1323,5 @@ case class CircularReferenceClassB(cls: CircularReferenceClassA)
 case class CircularReferenceClassC(ar: Array[CircularReferenceClassC])
 case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE])
 case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD])
+
+case class SpecialCharClass(`field.1`: String, `field 2`: String)

