Repository: spark Updated Branches: refs/heads/branch-2.2 8acd02f42 -> f73637798
[SPARK-22442][SQL][BRANCH-2.2] ScalaReflection should produce correct field names for special characters ## What changes were proposed in this pull request? For a class with field name of special characters, e.g.: ```scala case class MyType(`field.1`: String, `field 2`: String) ``` Although we can manipulate DataFrame/Dataset, the field names are encoded: ```scala scala> val df = Seq(MyType("a", "b"), MyType("c", "d")).toDF df: org.apache.spark.sql.DataFrame = [field$u002E1: string, field$u00202: string] scala> df.as[MyType].collect res7: Array[MyType] = Array(MyType(a,b), MyType(c,d)) ``` It causes a resolution problem when we try to convert the data with non-encoded field names: ```scala spark.read.json(path).as[MyType] ... [info] org.apache.spark.sql.AnalysisException: cannot resolve '`field$u002E1`' given input columns: [field 2, field.1]; [info] at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) ... ``` We should use decoded field name in Dataset schema. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh <vii...@gmail.com> Closes #19734 from viirya/SPARK-22442-2.2. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f7363779 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f7363779 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f7363779 Branch: refs/heads/branch-2.2 Commit: f736377980fd6cd46346512a68482a88aa6a1711 Parents: 8acd02f Author: Liang-Chi Hsieh <vii...@gmail.com> Authored: Sun Nov 12 21:19:15 2017 -0800 Committer: Felix Cheung <felixche...@apache.org> Committed: Sun Nov 12 21:19:15 2017 -0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/ScalaReflection.scala | 9 +++++---- .../catalyst/expressions/objects/objects.scala | 11 +++++++---- .../sql/catalyst/ScalaReflectionSuite.scala | 19 ++++++++++++++++++- .../org/apache/spark/sql/DatasetSuite.scala | 12 ++++++++++++ 4 files changed, 42 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f7363779/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index ad21842..7f72751 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -151,7 +151,7 @@ object ScalaReflection extends ScalaReflection { def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { val newPath = path .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) - .getOrElse(UnresolvedAttribute(part)) + .getOrElse(UnresolvedAttribute.quoted(part)) upCastToExpectedType(newPath, dataType, walkedTypePath) } @@ -671,7 +671,7 @@ object ScalaReflection 
extends ScalaReflection { val m = runtimeMirror(cls.getClassLoader) val classSymbol = m.staticClass(cls.getName) val t = classSymbol.selfType - constructParams(t).map(_.name.toString) + constructParams(t).map(_.name.decodedName.toString) } /** @@ -861,11 +861,12 @@ trait ScalaReflection { // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int]) if (actualTypeArgs.nonEmpty) { params.map { p => - p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + p.name.decodedName.toString -> + p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) } } else { params.map { p => - p.name.toString -> p.typeSignature + p.name.decodedName.toString -> p.typeSignature } } } http://git-wip-us.apache.org/repos/asf/spark/blob/f7363779/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 43cef6c..0b45dfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -27,6 +27,7 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} @@ -189,11 +190,13 @@ case class Invoke( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is 
supported.") + private lazy val encodedFunctionName = TermName(functionName).encodedName.toString + @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - val m = cls.getMethods.find(_.getName == functionName) + val m = cls.getMethods.find(_.getName == encodedFunctionName) if (m.isEmpty) { - sys.error(s"Couldn't find $functionName on $cls") + sys.error(s"Couldn't find $encodedFunctionName on $cls") } else { m } @@ -222,7 +225,7 @@ case class Invoke( } val evaluate = if (returnPrimitive) { - getFuncResult(ev.value, s"${obj.value}.$functionName($argString)") + getFuncResult(ev.value, s"${obj.value}.$encodedFunctionName($argString)") } else { val funcResult = ctx.freshName("funcResult") // If the function can return null, we do an extra check to make sure our null bit is still @@ -240,7 +243,7 @@ case class Invoke( } s""" Object $funcResult = null; - ${getFuncResult(funcResult, s"${obj.value}.$functionName($argString)")} + ${getFuncResult(funcResult, s"${obj.value}.$encodedFunctionName($argString)")} $assignResult """ } http://git-wip-us.apache.org/repos/asf/spark/blob/f7363779/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 35683ef..c0e4e37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -23,7 +23,8 @@ import java.sql.{Date, Timestamp} import scala.reflect.runtime.universe.typeOf import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import 
org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast} import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -83,6 +84,8 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } +case class SpecialCharAsFieldData(`field.1`: String, `field 2`: String) + object TestingUDT { @SQLUserDefinedType(udt = classOf[NestedStructUDT]) class NestedStruct(val a: Integer, val b: Long, val c: Double) @@ -354,4 +357,18 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(deserializerFor[Int].isInstanceOf[AssertNotNull]) assert(!deserializerFor[String].isInstanceOf[AssertNotNull]) } + + test("SPARK-22442: Generate correct field names for special characters") { + val serializer = serializerFor[SpecialCharAsFieldData](BoundReference( + 0, ObjectType(classOf[SpecialCharAsFieldData]), nullable = false)) + val deserializer = deserializerFor[SpecialCharAsFieldData] + assert(serializer.dataType(0).name == "field.1") + assert(serializer.dataType(1).name == "field 2") + + val argumentsFields = deserializer.asInstanceOf[NewInstance].arguments.flatMap { _.collect { + case UpCast(u: UnresolvedAttribute, _, _) => u.nameParts + }} + assert(argumentsFields(0) == Seq("field.1")) + assert(argumentsFields(1) == Seq("field 2")) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/f7363779/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 88a4167..683fe4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1228,6 +1228,16 @@ class 
DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getCause.isInstanceOf[NullPointerException]) } } + + test("SPARK-22442: Generate correct field names for special characters") { + withTempPath { dir => + val path = dir.getCanonicalPath + val data = """{"field.1": 1, "field 2": 2}""" + Seq(data).toDF().repartition(1).write.text(path) + val ds = spark.read.json(path).as[SpecialCharClass] + checkDataset(ds, SpecialCharClass("1", "2")) + } + } } case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) @@ -1313,3 +1323,5 @@ case class CircularReferenceClassB(cls: CircularReferenceClassA) case class CircularReferenceClassC(ar: Array[CircularReferenceClassC]) case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE]) case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD]) + +case class SpecialCharClass(`field.1`: String, `field 2`: String) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org