Repository: spark Updated Branches: refs/heads/master 1a8b2a17d -> a783a8ed4
[SPARK-12320][SQL] throw exception if the number of fields does not line up for Tuple encoder Author: Wenchen Fan <wenc...@databricks.com> Closes #10293 from cloud-fan/err-msg. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a783a8ed Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a783a8ed Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a783a8ed Branch: refs/heads/master Commit: a783a8ed49814a09fde653433a3d6de398ddf888 Parents: 1a8b2a1 Author: Wenchen Fan <wenc...@databricks.com> Authored: Wed Dec 16 13:18:56 2015 -0800 Committer: Michael Armbrust <mich...@databricks.com> Committed: Wed Dec 16 13:20:12 2015 -0800 ---------------------------------------------------------------------- .../apache/spark/sql/catalyst/dsl/package.scala | 3 +- .../catalyst/encoders/ExpressionEncoder.scala | 36 +++++++++++- .../expressions/complexTypeExtractors.scala | 10 ++-- .../encoders/EncoderResolutionSuite.scala | 60 +++++++++++++++++--- .../catalyst/expressions/ComplexTypeSuite.scala | 2 +- 5 files changed, 93 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/a783a8ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index e509711..8102c93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -227,9 +227,10 @@ package object dsl { AttributeReference(s, mapType, nullable = true)() /** Creates a new AttributeReference of type struct */ - def struct(fields: StructField*): AttributeReference = struct(StructType(fields)) def struct(structType: StructType): AttributeReference = AttributeReference(s, structType, nullable = true)() + def struct(attrs: AttributeReference*): AttributeReference = + struct(StructType.fromAttributes(attrs)) } implicit class DslAttribute(a: AttributeReference) { http://git-wip-us.apache.org/repos/asf/spark/blob/a783a8ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 363178b..7a4401c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -244,9 +244,41 @@ case class ExpressionEncoder[T]( def resolve( schema: Seq[Attribute], outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(schema) + def fail(st: StructType, maxOrdinal: Int): Unit = { + throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" + + " - Target schema: " + this.schema.simpleString) + } + + var maxOrdinal = -1 + fromRowExpression.foreach { + case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal + case _ => + } + if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) { + fail(StructType.fromAttributes(schema), maxOrdinal) + } + val unbound = fromRowExpression transform { - case b: BoundReference => positionToAttribute(b.ordinal) + case b: BoundReference => schema(b.ordinal) + } + + val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int] + unbound.foreach { + case g: GetStructField => + val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1) + if (maxOrdinal < g.ordinal) { + exprToMaxOrdinal.update(g.child, g.ordinal) + } + case _ => + } + exprToMaxOrdinal.foreach { + case (expr, maxOrdinal) => + val schema = expr.dataType.asInstanceOf[StructType] + if (maxOrdinal != schema.length - 1) { + fail(schema, maxOrdinal) + } } val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema)) http://git-wip-us.apache.org/repos/asf/spark/blob/a783a8ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 10ce10a..58f6a7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -104,14 +104,14 @@ object ExtractValue { case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) extends UnaryExpression { - private lazy val field = child.dataType.asInstanceOf[StructType](ordinal) + private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType] - override def dataType: DataType = field.dataType - override def nullable: Boolean = child.nullable || field.nullable - override def toString: String = s"$child.${name.getOrElse(field.name)}" + override def dataType: DataType = childSchema(ordinal).dataType + override def nullable: Boolean = child.nullable || childSchema(ordinal).nullable + override def toString: String = s"$child.${name.getOrElse(childSchema(ordinal).name)}" protected override def nullSafeEval(input: Any): Any = - input.asInstanceOf[InternalRow].get(ordinal, field.dataType) + input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { http://git-wip-us.apache.org/repos/asf/spark/blob/a783a8ed/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 0289988..815a03f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -64,22 +64,21 @@ class EncoderResolutionSuite extends PlanTest { val innerCls = classOf[StringLongClass] val cls = classOf[ComplexClass] - val structType = new StructType().add("a", IntegerType).add("b", LongType) - val attrs = Seq('a.int, 'b.struct(structType)) + val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression val expected: Expression = NewInstance( cls, Seq( 'a.int.cast(LongType), If( - 'b.struct(structType).isNull, + 'b.struct('a.int, 'b.long).isNull, Literal.create(null, ObjectType(innerCls)), NewInstance( innerCls, Seq( toExternalString( - GetStructField('b.struct(structType), 0, Some("a")).cast(StringType)), - GetStructField('b.struct(structType), 1, Some("b"))), + GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)), + GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))), false, ObjectType(innerCls)) )), @@ -94,8 +93,7 @@ class EncoderResolutionSuite extends PlanTest { ExpressionEncoder[Long]) val cls = classOf[StringLongClass] - val structType = new StructType().add("a", StringType).add("b", ByteType) - val attrs = Seq('a.struct(structType), 'b.int) + val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression val expected: Expression = NewInstance( classOf[Tuple2[_, _]], @@ -103,8 +101,8 @@ class EncoderResolutionSuite extends PlanTest { NewInstance( cls, Seq( - toExternalString(GetStructField('a.struct(structType), 0, Some("a"))), - GetStructField('a.struct(structType), 1, Some("b")).cast(LongType)), + toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))), + GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)), false, ObjectType(cls)), 'b.int.cast(LongType)), @@ -113,6 +111,50 @@ class EncoderResolutionSuite extends PlanTest { compareExpressions(fromRowExpr, expected) } + test("the real number of fields doesn't match encoder schema: tuple encoder") { + val encoder = ExpressionEncoder[(String, Long)] + + { + val attrs = Seq('a.string, 'b.long, 'c.int) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct<a:string,b:bigint,c:int> to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct<a:string,b:bigint,c:int>\n" + + " - Target schema: struct<_1:string,_2:bigint>") + } + + { + val attrs = Seq('a.string) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct<a:string> to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct<a:string>\n" + + " - Target schema: struct<_1:string,_2:bigint>") + } + } + + test("the real number of fields doesn't match encoder schema: nested tuple encoder") { + val encoder = ExpressionEncoder[(String, (Long, String))] + + { + val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct<x:bigint,y:string,z:int> to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct<a:string,b:struct<x:bigint,y:string,z:int>>\n" + + " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + } + + { + val attrs = Seq('a.string, 'b.struct('x.long)) + assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message == + "Try to map struct<x:bigint> to Tuple2, " + + "but failed as the number of fields does not line up.\n" + + " - Input schema: struct<a:string,b:struct<x:bigint>>\n" + + " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>") + } + } + private def toExternalString(e: Expression): Expression = { Invoke(e, "toString", ObjectType(classOf[String]), Nil) } http://git-wip-us.apache.org/repos/asf/spark/blob/a783a8ed/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 62fd472..9f1b192 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -165,7 +165,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { "b", create_row(Map("a" -> "b"))) checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)), "b", create_row(Seq("a", "b"))) - checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")), + checkEvaluation(quickResolve('c.struct('a.int).at(0).getField("a")), 1, create_row(create_row(1))) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org