diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index ea8c369ee49ed..7ae5924b20faf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -86,4 +86,13 @@ object BindReferences extends Logging { } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. } + + /** + * A helper function to bind given expressions to an input schema. + */ + def bindReferences[A <: Expression]( + expressions: Seq[A], + input: AttributeSeq): Seq[A] = { + expressions.map(BindReferences.bindReference(_, input)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala index 122a564da61be..5c8aa4e2e9d83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp @@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp */ class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = - this(toBoundExprs(expressions, inputSchema)) + this(bindReferences(expressions, inputSchema)) private[this] val buffer = new Array[Any](expressions.size) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index b48f7ba655b2f..eaaf94baac216 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -30,7 +31,7 @@ import org.apache.spark.sql.types.{DataType, StructType} */ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = - this(expressions.map(BindReferences.bindReference(_, inputSchema))) + this(bindReferences(expressions, inputSchema)) override def initialize(partitionIndex: Int): Unit = { expressions.foreach(_.foreach { @@ -99,7 +100,7 @@ object MutableProjection * `inputSchema`. */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = { - create(toBoundExprs(exprs, inputSchema)) + create(bindReferences(exprs, inputSchema)) } } @@ -162,7 +163,7 @@ object UnsafeProjection * `inputSchema`. */ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { - create(toBoundExprs(exprs, inputSchema)) + create(bindReferences(exprs, inputSchema)) } } @@ -203,6 +204,6 @@ object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expressio * `inputSchema`. 
*/ def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { - create(toBoundExprs(exprs, inputSchema)) + create(bindReferences(exprs, inputSchema)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index d588e7f081303..838bd1c679e4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp // MutableProjection is not accessible in Java @@ -35,7 +36,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP in.map(ExpressionCanonicalizer.execute) protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) + bindReferences(in, inputSchema) def generate( expressions: Seq[Expression], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 283fd2a6e9383..b66b80ad31dc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow 
import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -46,7 +47,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] = - in.map(BindReferences.bindReference(_, inputSchema)) + bindReferences(in, inputSchema) /** * Creates a code gen ordering for sorting this schema, in ascending order. @@ -188,7 +189,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder]) extends Ordering[InternalRow] with KryoSerializable { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = - this(ordering.map(BindReferences.bindReference(_, inputSchema))) + this(bindReferences(ordering, inputSchema)) @transient private[this] var generatedOrdering = GenerateOrdering.generate(ordering) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 39778661d1c48..e285398ba1958 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -21,6 +21,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} @@ -41,7 +42,7 
@@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] in.map(ExpressionCanonicalizer.execute) protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) + bindReferences(in, inputSchema) private def createCodeForStruct( ctx: CodegenContext, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 0ecd0de8d8203..fb1d8a3c8e739 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -317,7 +318,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro in.map(ExpressionCanonicalizer.execute) protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) + bindReferences(in, inputSchema) def generate( expressions: Seq[Expression], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index e24a3de3cfdbe..c8d667143f452 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -18,6 +18,7 @@ package 
org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.types._ @@ -27,7 +28,7 @@ import org.apache.spark.sql.types._ class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = - this(ordering.map(BindReferences.bindReference(_, inputSchema))) + this(bindReferences(ordering, inputSchema)) def compare(a: InternalRow, b: InternalRow): Int = { var i = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index bf18e8bcb52df..932c364737249 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -85,13 +85,6 @@ package object expressions { override def apply(row: InternalRow): InternalRow = row } - /** - * A helper function to bind given expressions to an input schema. - */ - def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = { - exprs.map(BindReferences.bindReference(_, inputSchema)) - } - /** * Helper functions for working with `Seq[Attribute]`. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala index 5b4edf5136e3f..85f49140a4b41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala @@ -145,11 +145,12 @@ case class ExpandExec( // Part 1: declare variables for each column // If a column has the same value for all output rows, then we also generate its computation // right after declaration. 
Otherwise its value is computed in the part 2. + lazy val attributeSeq: AttributeSeq = child.output val outputColumns = output.indices.map { col => val firstExpr = projections.head(col) if (sameOutput(col)) { // This column is the same across all output rows. Just generate code for it here. - BindReferences.bindReference(firstExpr, child.output).genCode(ctx) + BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx) } else { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") @@ -170,7 +171,7 @@ case class ExpandExec( var updateCode = "" for (col <- exprs.indices) { if (!sameOutput(col)) { - val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx) + val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx) updateCode += s""" |${ev.code} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 98c4a51299958..a1fb23d621d49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -77,6 +77,7 @@ abstract class AggregationIterator( val expressionsLength = expressions.length val functions = new Array[AggregateFunction](expressionsLength) var i = 0 + val inputAttributeSeq: AttributeSeq = inputAttributes while (i < expressionsLength) { val func = expressions(i).aggregateFunction val funcWithBoundReferences: AggregateFunction = expressions(i).mode match { @@ -86,7 +87,7 @@ abstract class AggregationIterator( // this function is Partial or Complete because we will call eval of this // function's children in the update method of this aggregate function. // Those eval calls require BoundReferences to work. 
- BindReferences.bindReference(func, inputAttributes) + BindReferences.bindReference(func, inputAttributeSeq) case _ => // We only need to set inputBufferOffset for aggregate functions with mode // PartialMerge and Final. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 4827f838fc514..28801774418a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -199,15 +200,13 @@ case class HashAggregateExec( val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) { // evaluate aggregate results ctx.currentVars = bufVars - val aggResults = functions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx) - } + val aggResults = bindReferences( + functions.map(_.evaluateExpression), + aggregateBufferAttributes).map(_.genCode(ctx)) val evaluateAggResults = evaluateVariables(aggResults) // evaluate result expressions ctx.currentVars = aggResults - val resultVars = resultExpressions.map { e => - BindReferences.bindReference(e, aggregateAttributes).genCode(ctx) - } + val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx)) (resultVars, s""" |$evaluateAggResults |${evaluateVariables(resultVars)} @@ -264,7 +263,7 @@ case 
class HashAggregateExec( } } ctx.currentVars = bufVars ++ input - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) + val boundUpdateExpr = bindReferences(updateExpr, inputAttrs) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { @@ -456,16 +455,16 @@ case class HashAggregateExec( val evaluateBufferVars = evaluateVariables(bufferVars) // evaluate the aggregation result ctx.currentVars = bufferVars - val aggResults = declFunctions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx) - } + val aggResults = bindReferences( + declFunctions.map(_.evaluateExpression), + aggregateBufferAttributes).map(_.genCode(ctx)) val evaluateAggResults = evaluateVariables(aggResults) // generate the final result ctx.currentVars = keyVars ++ aggResults val inputAttrs = groupingAttributes ++ aggregateAttributes - val resultVars = resultExpressions.map { e => - BindReferences.bindReference(e, inputAttrs).genCode(ctx) - } + val resultVars = bindReferences[Expression]( + resultExpressions, + inputAttrs).map(_.genCode(ctx)) s""" $evaluateKeyVars $evaluateBufferVars @@ -494,9 +493,9 @@ case class HashAggregateExec( ctx.currentVars = keyVars ++ resultBufferVars val inputAttrs = resultExpressions.map(_.toAttribute) - val resultVars = resultExpressions.map { e => - BindReferences.bindReference(e, inputAttrs).genCode(ctx) - } + val resultVars = bindReferences[Expression]( + resultExpressions, + inputAttrs).map(_.genCode(ctx)) s""" $evaluateKeyVars $evaluateResultBufferVars @@ -506,9 +505,9 @@ case class HashAggregateExec( // generate result based on grouping key ctx.INPUT_ROW = keyTerm ctx.currentVars = null - val eval = resultExpressions.map{ e => - BindReferences.bindReference(e, groupingAttributes).genCode(ctx) - } + val eval = bindReferences[Expression]( + 
resultExpressions, + groupingAttributes).map(_.genCode(ctx)) consume(ctx, eval) } ctx.addNewFunction(funcName, @@ -730,9 +729,9 @@ case class HashAggregateExec( private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // create grouping key val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( - ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + ctx, bindReferences[Expression](groupingExpressions, child.output)) val fastRowKeys = ctx.generateExpressions( - groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + bindReferences[Expression](groupingExpressions, child.output)) val unsafeRowKeys = unsafeRowKeyCode.value val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer") val fastRowBuffer = ctx.freshName("fastAggBuffer") @@ -825,7 +824,7 @@ case class HashAggregateExec( val updateRowInRegularHashMap: String = { ctx.INPUT_ROW = unsafeRowBuffer - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val boundUpdateExpr = bindReferences(updateExpr, inputAttr) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) { @@ -849,7 +848,7 @@ case class HashAggregateExec( if (isFastHashMapEnabled) { if (isVectorizedHashMapEnabled) { ctx.INPUT_ROW = fastRowBuffer - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) + val boundUpdateExpr = bindReferences(updateExpr, inputAttr) val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) val effectiveCodes = subExprs.codes.mkString("\n") val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala 
index 09effe087e195..7e87a150c0a4d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -24,6 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics @@ -56,7 +57,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan) } override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { - val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output)) + val exprs = bindReferences[Expression](projectList, child.output) val resultVars = exprs.map(_.genCode(ctx)) // Evaluation of non-deterministic expressions can't be deferred. 
val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 774fe38f5c2e6..260ad97506a85 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} @@ -145,9 +146,8 @@ object FileFormatWriter extends Logging { // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and // the physical plan may have different attribute ids due to optimizer removing some // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. 
- val orderingExpr = requiredOrdering - .map(SortOrder(_, Ascending)) - .map(BindReferences.bindReference(_, outputSpec.outputColumns)) + val orderingExpr = bindReferences( + requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns) SortExec( orderingExpr, global = false, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 1aef5f6864263..5ee4c7ffb1911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{RowIterator, SparkPlan} @@ -63,9 +64,8 @@ trait HashJoin { protected lazy val (buildKeys, streamedKeys) = { require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), "Join keys from two sides should have same types") - val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) - val rkeys = HashJoin.rewriteKeyExpr(rightKeys) - .map(BindReferences.bindReference(_, right.output)) + val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output) + val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output) buildSide match { case BuildLeft => (lkeys, rkeys) case BuildRight => (rkeys, lkeys) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d7d3f6d6078b4..f829f07e80720 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans._ @@ -393,7 +394,7 @@ case class SortMergeJoinExec( input: Seq[Attribute]): Seq[ExprCode] = { ctx.INPUT_ROW = row ctx.currentVars = null - keys.map(BindReferences.bindReference(_, input).genCode(ctx)) + bindReferences(keys, input).map(_.genCode(ctx)) } private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala index 156002ef58fbe..5bf34558fe493 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala @@ -21,6 +21,7 @@ import java.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray @@ -89,9 +90,8 @@ private[window] final class OffsetWindowFunctionFrame( private[this] val projection = { // Collect the expressions and bind them. 
val inputAttrs = inputSchema.map(_.withNullability(true)) - val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e => - BindReferences.bindReference(e.input, inputAttrs) - } + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ bindReferences( + expressions.toSeq.map(_.input), inputAttrs) // Create the projection. newMutableProjection(boundExpressions, Nil).target(target) @@ -100,7 +100,7 @@ private[window] final class OffsetWindowFunctionFrame( /** Create the projection used when the offset row DOES NOT exists. */ private[this] val fillDefaultValue = { // Collect the expressions and bind them. - val inputAttrs = inputSchema.map(_.withNullability(true)) + val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true)) val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e => if (e.default == null || e.default.foldable && e.default.eval() == null) { // The default value is null.
With regards, Apache Git Services --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org