Repository: spark
Updated Branches:
  refs/heads/master 4c74ee8d8 -> 0b8d69499


[SPARK-15764][SQL] Replace N^2 loop in BindReferences

BindReferences contains an O(n^2) loop which causes performance issues when
operating over large schemas: to determine the ordinal of an attribute
reference, we perform a linear scan over the `input` array. Because `input` can
sometimes be a `List`, the call to `input(ordinal).nullable` can also be O(n).

Instead of performing a linear scan, we can convert the input into an array and
build a hash map from expression ids to ordinals. The higher up-front cost of
constructing the map is offset by the fact that an expression can contain
multiple attribute references, so the construction cost is amortized across a
number of lookups.

Perf. benchmarks to follow. /cc ericl

Author: Josh Rosen <joshro...@databricks.com>

Closes #13505 from JoshRosen/bind-references-improvement.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0b8d6949
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0b8d6949
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0b8d6949

Branch: refs/heads/master
Commit: 0b8d694999b43ada4833388cad6c285c7757cbf7
Parents: 4c74ee8
Author: Josh Rosen <joshro...@databricks.com>
Authored: Mon Jun 6 11:44:51 2016 -0700
Committer: Josh Rosen <joshro...@databricks.com>
Committed: Mon Jun 6 11:44:51 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/AttributeMap.scala |  7 ----
 .../catalyst/expressions/BoundAttribute.scala   |  6 ++--
 .../sql/catalyst/expressions/package.scala      | 34 +++++++++++++++++++-
 .../spark/sql/catalyst/plans/QueryPlan.scala    |  2 +-
 .../execution/aggregate/HashAggregateExec.scala |  2 +-
 .../columnar/InMemoryTableScanExec.scala        |  4 +--
 6 files changed, 40 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0b8d6949/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index ef3cc55..96a11e3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -26,13 +26,6 @@ object AttributeMap {
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
-
-  /** Given a schema, constructs an [[AttributeMap]] from [[Attribute]] to 
ordinal */
-  def byIndex(schema: Seq[Attribute]): AttributeMap[Int] = 
apply(schema.zipWithIndex)
-
-  /** Given a schema, constructs a map from ordinal to Attribute. */
-  def toIndex(schema: Seq[Attribute]): Map[Int, Attribute] =
-    schema.zipWithIndex.map { case (a, i) => i -> a }.toMap
 }
 
 class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])

http://git-wip-us.apache.org/repos/asf/spark/blob/0b8d6949/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index a38f1ec..7d16118 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -82,16 +82,16 @@ object BindReferences extends Logging {
 
   def bindReference[A <: Expression](
       expression: A,
-      input: Seq[Attribute],
+      input: AttributeSeq,
       allowFailures: Boolean = false): A = {
     expression.transform { case a: AttributeReference =>
       attachTree(a, "Binding attribute") {
-        val ordinal = input.indexWhere(_.exprId == a.exprId)
+        val ordinal = input.indexOf(a.exprId)
         if (ordinal == -1) {
           if (allowFailures) {
             a
           } else {
-            sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+            sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", 
"]")}")
           }
         } else {
           BoundReference(ordinal, a.dataType, input(ordinal).nullable)

http://git-wip-us.apache.org/repos/asf/spark/blob/0b8d6949/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 23baa6f..81f5bb4 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst
 
+import com.google.common.collect.Maps
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.{StructField, StructType}
 
@@ -86,11 +88,41 @@ package object expressions  {
   /**
    * Helper functions for working with `Seq[Attribute]`.
    */
-  implicit class AttributeSeq(attrs: Seq[Attribute]) {
+  implicit class AttributeSeq(val attrs: Seq[Attribute]) extends Serializable {
     /** Creates a StructType with a schema matching this `Seq[Attribute]`. */
     def toStructType: StructType = {
       StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
     }
+
+    // It's possible that `attrs` is a linked list, which can lead to bad 
O(n^2) loops when
+    // accessing attributes by their ordinals. To avoid this performance 
penalty, convert the input
+    // to an array.
+    @transient private lazy val attrsArray = attrs.toArray
+
+    @transient private lazy val exprIdToOrdinal = {
+      val arr = attrsArray
+      val map = Maps.newHashMapWithExpectedSize[ExprId, Int](arr.length)
+      // Iterate over the array in reverse order so that the final map value 
is the first attribute
+      // with a given expression id.
+      var index = arr.length - 1
+      while (index >= 0) {
+        map.put(arr(index).exprId, index)
+        index -= 1
+      }
+      map
+    }
+
+    /**
+     * Returns the attribute at the given index.
+     */
+    def apply(ordinal: Int): Attribute = attrsArray(ordinal)
+
+    /**
+     * Returns the index of first attribute with a matching expression id, or 
-1 if no match exists.
+     */
+    def indexOf(exprId: ExprId): Int = {
+      Option(exprIdToOrdinal.get(exprId)).getOrElse(-1)
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/0b8d6949/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 3de15a9..19a66cf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -296,7 +296,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] 
extends TreeNode[PlanT
   /**
    * All the attributes that are used for this plan.
    */
-  lazy val allAttributes: Seq[Attribute] = children.flatMap(_.output)
+  lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
 
   private def cleanExpression(e: Expression): Expression = e match {
     case a: Alias =>

http://git-wip-us.apache.org/repos/asf/spark/blob/0b8d6949/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index f5bc062..f270ca0 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -49,7 +49,7 @@ case class HashAggregateExec(
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
-  override lazy val allAttributes: Seq[Attribute] =
+  override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
       
aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0b8d6949/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index bd55e1a..a1c2f0a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -310,7 +310,7 @@ private[sql] case class InMemoryTableScanExec(
     // within the map Partitions closure.
     val schema = relation.partitionStatistics.schema
     val schemaIndex = schema.zipWithIndex
-    val relOutput = relation.output
+    val relOutput: AttributeSeq = relation.output
     val buffers = relation.cachedColumnBuffers
 
     buffers.mapPartitionsInternal { cachedBatchIterator =>
@@ -321,7 +321,7 @@ private[sql] case class InMemoryTableScanExec(
       // Find the ordinals and data types of the requested columns.
       val (requestedColumnIndices, requestedColumnDataTypes) =
         attributes.map { a =>
-          relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType
+          relOutput.indexOf(a.exprId) -> a.dataType
         }.unzip
 
       // Do partition batch pruning if enabled


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to