Repository: spark
Updated Branches:
  refs/heads/master d728d5c98 -> 67e23b39a


[SPARK-10429] [SQL] make mutableProjection atomic

Right now, SQL's mutable projection updates each slot of the mutable row 
immediately after it evaluates the corresponding expression. This makes the 
behavior of MutableProjection confusing and complicates the implementation of 
common aggregate functions like stddev, because developers need to be aware 
that when the (i+1)-th expression of a mutable projection is evaluated, the 
i-th slot of the mutable row has already been updated.

This PR makes MutableProjection atomic by evaluating all of the expressions 
first and then copying the results into mutableRow.

I ran a micro-benchmark; there is no notable performance difference between 
using class members and local variables.

cc yhuai

Author: Davies Liu <[email protected]>

Closes #9422 from davies/atomic_mutable and squashes the following commits:

bbc1758 [Davies Liu] support wide table
8a0ae14 [Davies Liu] fix bug
bec07da [Davies Liu] refactor
2891628 [Davies Liu] make mutableProjection atomic


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/67e23b39
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/67e23b39
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/67e23b39

Branch: refs/heads/master
Commit: 67e23b39ac3cdee06668fa9131951278b9731e29
Parents: d728d5c
Author: Davies Liu <[email protected]>
Authored: Tue Nov 3 11:42:08 2015 +0100
Committer: Michael Armbrust <[email protected]>
Committed: Tue Nov 3 11:42:08 2015 +0100

----------------------------------------------------------------------
 .../sql/catalyst/expressions/Projection.scala   |  13 +-
 .../expressions/aggregate/functions.scala       | 154 ++++++++-----------
 .../codegen/GenerateMutableProjection.scala     |  28 +++-
 3 files changed, 97 insertions(+), 98 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/67e23b39/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index afe52e6..a6fe730 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
 import 
org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, 
GenerateUnsafeProjection}
-import org.apache.spark.sql.types.{DataType, Decimal, StructType, _}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * A [[Projection]] that is calculated by calling the `eval` of each of the 
specified expressions.
@@ -62,6 +61,8 @@ case class InterpretedMutableProjection(expressions: 
Seq[Expression]) extends Mu
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
+  private[this] val buffer = new Array[Any](expressions.size)
+
   expressions.foreach(_.foreach {
     case n: Nondeterministic => n.setInitialValues()
     case _ =>
@@ -79,7 +80,13 @@ case class InterpretedMutableProjection(expressions: 
Seq[Expression]) extends Mu
   override def apply(input: InternalRow): InternalRow = {
     var i = 0
     while (i < exprArray.length) {
-      mutableRow(i) = exprArray(i).eval(input)
+      // Store the result into buffer first, to make the projection atomic 
(needed by aggregation)
+      buffer(i) = exprArray(i).eval(input)
+      i += 1
+    }
+    i = 0
+    while (i < exprArray.length) {
+      mutableRow(i) = buffer(i)
       i += 1
     }
     mutableRow

http://git-wip-us.apache.org/repos/asf/spark/blob/67e23b39/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 5d2eb7b..f2c3eca 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -57,37 +57,37 @@ case class Average(child: Expression) extends 
DeclarativeAggregate {
     case _ => DoubleType
   }
 
-  private val currentSum = AttributeReference("currentSum", sumDataType)()
-  private val currentCount = AttributeReference("currentCount", LongType)()
+  private val sum = AttributeReference("sum", sumDataType)()
+  private val count = AttributeReference("count", LongType)()
 
-  override val aggBufferAttributes = currentSum :: currentCount :: Nil
+  override val aggBufferAttributes = sum :: count :: Nil
 
   override val initialValues = Seq(
-    /* currentSum = */ Cast(Literal(0), sumDataType),
-    /* currentCount = */ Literal(0L)
+    /* sum = */ Cast(Literal(0), sumDataType),
+    /* count = */ Literal(0L)
   )
 
   override val updateExpressions = Seq(
-    /* currentSum = */
+    /* sum = */
     Add(
-      currentSum,
+      sum,
       Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: 
Nil)),
-    /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+    /* count = */ If(IsNull(child), count, count + 1L)
   )
 
   override val mergeExpressions = Seq(
-    /* currentSum = */ currentSum.left + currentSum.right,
-    /* currentCount = */ currentCount.left + currentCount.right
+    /* sum = */ sum.left + sum.right,
+    /* count = */ count.left + count.right
   )
 
-  // If all input are nulls, currentCount will be 0 and we will get null after 
the division.
+  // If all input are nulls, count will be 0 and we will get null after the 
division.
   override val evaluateExpression = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       // increase the precision and scale to prevent precision loss
       val dt = DecimalType.bounded(p + 14, s + 4)
-      Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType)
+      Cast(Cast(sum, dt) / Cast(count, dt), resultType)
     case _ =>
-      Cast(currentSum, resultType) / Cast(currentCount, resultType)
+      Cast(sum, resultType) / Cast(count, resultType)
   }
 }
 
@@ -102,23 +102,23 @@ case class Count(child: Expression) extends 
DeclarativeAggregate {
   // Expected input data type.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
 
-  private val currentCount = AttributeReference("currentCount", LongType)()
+  private val count = AttributeReference("count", LongType)()
 
-  override val aggBufferAttributes = currentCount :: Nil
+  override val aggBufferAttributes = count :: Nil
 
   override val initialValues = Seq(
-    /* currentCount = */ Literal(0L)
+    /* count = */ Literal(0L)
   )
 
   override val updateExpressions = Seq(
-    /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+    /* count = */ If(IsNull(child), count, count + 1L)
   )
 
   override val mergeExpressions = Seq(
-    /* currentCount = */ currentCount.left + currentCount.right
+    /* count = */ count.left + count.right
   )
 
-  override val evaluateExpression = Cast(currentCount, LongType)
+  override val evaluateExpression = Cast(count, LongType)
 }
 
 /**
@@ -372,101 +372,77 @@ abstract class StddevAgg(child: Expression) extends 
DeclarativeAggregate {
 
   private val resultType = DoubleType
 
-  private val preCount = AttributeReference("preCount", resultType)()
-  private val currentCount = AttributeReference("currentCount", resultType)()
-  private val preAvg = AttributeReference("preAvg", resultType)()
-  private val currentAvg = AttributeReference("currentAvg", resultType)()
-  private val currentMk = AttributeReference("currentMk", resultType)()
+  private val count = AttributeReference("count", resultType)()
+  private val avg = AttributeReference("avg", resultType)()
+  private val mk = AttributeReference("mk", resultType)()
 
-  override val aggBufferAttributes = preCount :: currentCount :: preAvg ::
-                                  currentAvg :: currentMk :: Nil
+  override val aggBufferAttributes = count :: avg :: mk :: Nil
 
   override val initialValues = Seq(
-    /* preCount = */ Cast(Literal(0), resultType),
-    /* currentCount = */ Cast(Literal(0), resultType),
-    /* preAvg = */ Cast(Literal(0), resultType),
-    /* currentAvg = */ Cast(Literal(0), resultType),
-    /* currentMk = */ Cast(Literal(0), resultType)
+    /* count = */ Cast(Literal(0), resultType),
+    /* avg = */ Cast(Literal(0), resultType),
+    /* mk = */ Cast(Literal(0), resultType)
   )
 
   override val updateExpressions = {
+    val value = Cast(child, resultType)
+    val newCount = count + Cast(Literal(1), resultType)
 
     // update average
     // avg = avg + (value - avg)/count
-    def avgAdd: Expression = {
-      currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount)
-    }
+    val newAvg = avg + (value - avg) / newCount
 
     // update sum of square of difference from mean
     // Mk = Mk + (value - preAvg) * (value - updatedAvg)
-    def mkAdd: Expression = {
-      val delta1 = Cast(child, resultType) - preAvg
-      val delta2 = Cast(child, resultType) - currentAvg
-      currentMk + (delta1 * delta2)
-    }
+    val newMk = mk + (value - avg) * (value - newAvg)
 
     Seq(
-      /* preCount = */ If(IsNull(child), preCount, currentCount),
-      /* currentCount = */ If(IsNull(child), currentCount,
-                           Add(currentCount, Cast(Literal(1), resultType))),
-      /* preAvg = */ If(IsNull(child), preAvg, currentAvg),
-      /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd),
-      /* currentMk = */ If(IsNull(child), currentMk, mkAdd)
+      /* count = */ If(IsNull(child), count, newCount),
+      /* avg = */ If(IsNull(child), avg, newAvg),
+      /* mk = */ If(IsNull(child), mk, newMk)
     )
   }
 
   override val mergeExpressions = {
 
     // count merge
-    def countMerge: Expression = {
-      currentCount.left + currentCount.right
-    }
+    val newCount = count.left + count.right
 
     // average merge
-    def avgMerge: Expression = {
-      ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) 
/
-      (preCount + currentCount.right)
-    }
+    val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / 
newCount
 
     // update sum of square differences
-    def mkMerge: Expression = {
-      val avgDelta = currentAvg.right - preAvg
-      val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) /
-        (preCount + currentCount.right)
-
-      currentMk.left + currentMk.right + mkDelta
+    val newMk = {
+      val avgDelta = avg.right - avg.left
+      val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / 
newCount
+      mk.left + mk.right + mkDelta
     }
 
     Seq(
-      /* preCount = */ If(IsNull(currentCount.left),
-                         Cast(Literal(0), resultType), currentCount.left),
-      /* currentCount = */ If(IsNull(currentCount.left), currentCount.right,
-                             If(IsNull(currentCount.right), currentCount.left, 
countMerge)),
-      /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), 
currentAvg.left),
-      /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right,
-                           If(IsNull(currentAvg.right), currentAvg.left, 
avgMerge)),
-      /* currentMk = */ If(IsNull(currentMk.left), currentMk.right,
-                          If(IsNull(currentMk.right), currentMk.left, mkMerge))
+      /* count = */ If(IsNull(count.left), count.right,
+                       If(IsNull(count.right), count.left, newCount)),
+      /* avg = */ If(IsNull(avg.left), avg.right,
+                     If(IsNull(avg.right), avg.left, newAvg)),
+      /* mk = */ If(IsNull(mk.left), mk.right,
+                    If(IsNull(mk.right), mk.left, newMk))
     )
   }
 
   override val evaluateExpression = {
-    // when currentCount == 0, return null
-    // when currentCount == 1, return 0
-    // when currentCount >1
-    // stddev_samp = sqrt (currentMk/(currentCount -1))
-    // stddev_pop = sqrt (currentMk/currentCount)
-    val varCol = {
+    // when count == 0, return null
+    // when count == 1, return 0
+    // when count >1
+    // stddev_samp = sqrt (mk/(count -1))
+    // stddev_pop = sqrt (mk/count)
+    val varCol =
       if (isSample) {
-        currentMk / Cast((currentCount - Cast(Literal(1), resultType)), 
resultType)
-      }
-      else {
-        currentMk / currentCount
+        mk / Cast((count - Cast(Literal(1), resultType)), resultType)
+      } else {
+        mk / count
       }
-    }
 
-    If(EqualTo(currentCount, Cast(Literal(0), resultType)), 
Cast(Literal(null), resultType),
-      If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), 
resultType),
+    If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), 
resultType),
+      If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), 
resultType),
         Cast(Sqrt(varCol), resultType)))
   }
 }
@@ -499,30 +475,30 @@ case class Sum(child: Expression) extends 
DeclarativeAggregate {
 
   private val sumDataType = resultType
 
-  private val currentSum = AttributeReference("currentSum", sumDataType)()
+  private val sum = AttributeReference("sum", sumDataType)()
 
   private val zero = Cast(Literal(0), sumDataType)
 
-  override val aggBufferAttributes = currentSum :: Nil
+  override val aggBufferAttributes = sum :: Nil
 
   override val initialValues = Seq(
-    /* currentSum = */ Literal.create(null, sumDataType)
+    /* sum = */ Literal.create(null, sumDataType)
   )
 
   override val updateExpressions = Seq(
-    /* currentSum = */
-    Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, 
sumDataType)), currentSum))
+    /* sum = */
+    Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
   )
 
   override val mergeExpressions = {
-    val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, 
sumDataType))
+    val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType))
     Seq(
-      /* currentSum = */
-      Coalesce(Seq(add, currentSum.left))
+      /* sum = */
+      Coalesce(Seq(add, sum.left))
     )
   }
 
-  override val evaluateExpression = Cast(currentSum, resultType)
+  override val evaluateExpression = Cast(sum, resultType)
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/67e23b39/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index e8ee647..4b66069 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -44,28 +44,42 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], () => Mu
       case (NoOp, _) => ""
       case (e, i) =>
         val evaluationCode = e.gen(ctx)
+        val isNull = s"isNull_$i"
+        val value = s"value_$i"
+        ctx.addMutableState("boolean", isNull, s"this.$isNull = true;")
+        ctx.addMutableState(ctx.javaType(e.dataType), value,
+          s"this.$value = ${ctx.defaultValue(e.dataType)};")
+        s"""
+          ${evaluationCode.code}
+          this.$isNull = ${evaluationCode.isNull};
+          this.$value = ${evaluationCode.value};
+         """
+    }
+    val updates = expressions.zipWithIndex.map {
+      case (NoOp, _) => ""
+      case (e, i) =>
         if (e.dataType.isInstanceOf[DecimalType]) {
           // Can't call setNullAt on DecimalType, because we need to keep the 
offset
           s"""
-            ${evaluationCode.code}
-            if (${evaluationCode.isNull}) {
+            if (this.isNull_$i) {
               ${ctx.setColumn("mutableRow", e.dataType, i, null)};
             } else {
-              ${ctx.setColumn("mutableRow", e.dataType, i, 
evaluationCode.value)};
+              ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
             }
           """
         } else {
           s"""
-            ${evaluationCode.code}
-            if (${evaluationCode.isNull}) {
+            if (this.isNull_$i) {
               mutableRow.setNullAt($i);
             } else {
-              ${ctx.setColumn("mutableRow", e.dataType, i, 
evaluationCode.value)};
+              ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
             }
           """
         }
     }
+
     val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
+    val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)
 
     val code = s"""
       public Object generate($exprType[] expr) {
@@ -98,6 +112,8 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], () => Mu
         public Object apply(Object _i) {
           InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
           $allProjections
+          // copy all the results into MutableRow
+          $allUpdates
           return mutableRow;
         }
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to