Github user marmbrus commented on a diff in the pull request:
https://github.com/apache/spark/pull/9840#discussion_r45562684
--- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala ---
@@ -124,17 +124,46 @@ object ScalaReflection extends ScalaReflection {
       path: Option[Expression]): Expression = ScalaReflectionLock.synchronized {
     /** Returns the current path with a sub-field extracted. */
-    def addToPath(part: String): Expression = path
-      .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
-      .getOrElse(UnresolvedAttribute(part))
+    def addToPath(part: String, dataType: DataType): Expression = {
+      val newPath = path
+        .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
+        .getOrElse(UnresolvedAttribute(part))
+      castToExpectedType(newPath, dataType)
+    }
     /** Returns the current path with a field at ordinal extracted. */
-    def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
-      .map(p => GetInternalRowField(p, ordinal, dataType))
-      .getOrElse(BoundReference(ordinal, dataType, false))
+    def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = {
+      val newPath = path
+        .map(p => GetStructField(p, new StructField("", dataType), ordinal))
+        .getOrElse(BoundReference(ordinal, dataType, false))
+      castToExpectedType(newPath, dataType)
+    }
     /** Returns the current path or `BoundReference`. */
-    def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true))
+    def getPath: Expression = {
+      val dataType = schemaFor(tpe).dataType
+      path.getOrElse(castToExpectedType(BoundReference(0, dataType, true), dataType))
+    }
+
+    /**
+     * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff
+     * and lose the required data type, which may lead to a runtime error if the real type doesn't
+     * match the encoder's schema.
+     * For example, if we build an encoder for `case class Data(a: Int, b: String)` but the real
+     * type is [a: int, b: long], we will hit a runtime error saying that we can't construct class
+     * `Data` from an int and a long, because we lost the information that `b` should be a string.
+     *
+     * This method helps us "remember" the required data type by adding a `Cast`. Note that we
+     * don't need to add a `Cast` for struct types because there must be an
+     * `UnresolvedExtractValue` or `GetStructField` wrapping them.
+     *
+     * TODO: this only works if the real type is compatible with the encoder's schema; we should
+     * also handle error cases.
--- End diff ---
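
To make the scenario described in the new doc comment concrete, here is a minimal sketch (hypothetical code, not taken from the PR; it assumes a SparkSession with `spark.implicits._` in scope):

```scala
// Hypothetical illustration of the doc comment's example: the encoder expects
// [a: int, b: string] but the underlying data is [a: int, b: bigint].
case class Data(a: Int, b: String)

val df = Seq((1, 2L)).toDF("a", "b")   // schema: a int, b bigint

// Before this change the mismatch surfaced as a runtime failure when
// constructing `Data`; with the inserted Cast, `b` is converted to a string.
df.as[Data].collect()
```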
I'm not sure we want to automatically downcast in cases where the values could be
truncated. Unlike an explicit cast, where the user is asking for it, I think this
could be confusing. Consider the following:
```scala
scala> case class Data(value: Int)
scala> Seq(Int.MaxValue.toLong + 1).toDS().as[Data].collect()
res6: Array[Data] = Array(Data(-2147483648))
```
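
To spell out the contrast with an explicit cast, here is a minimal sketch (hypothetical code, assuming `spark.implicits._` is in scope): the explicit form keeps the narrowing visible in user code, while the encoder-inserted `Cast` performs the same truncation silently.

```scala
case class Data(value: Int)

val ds = Seq(Int.MaxValue.toLong + 1).toDS()   // single column "value" of type bigint

// Explicit, user-requested downcast: the overflow is opted into visibly.
ds.selectExpr("cast(value as int) as value").as[Data].collect()

// Implicit downcast via the encoder's inserted Cast: the same overflow
// happens silently, yielding Data(-2147483648).
ds.as[Data].collect()
```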
---