Repository: spark
Updated Branches:
refs/heads/master d728d5c98 -> 67e23b39a
[SPARK-10429] [SQL] make mutableProjection atomic
Right now, SQL's mutable projection updates each value of the mutable row
immediately after it evaluates the corresponding expression. This makes the
behavior of MutableProjection confusing and complicates the implementation of
common aggregate functions like stddev, because developers need to be aware
that when evaluating the {{i+1}}th expression of a mutable projection, the
{{i}}th slot of the mutable row has already been updated.
This PR makes MutableProjection atomic by evaluating all of the expressions
first and then copying the results into mutableRow.
A micro-benchmark showed no notable performance difference between using
class members and local variables.
cc yhuai
Author: Davies Liu <[email protected]>
Closes #9422 from davies/atomic_mutable and squashes the following commits:
bbc1758 [Davies Liu] support wide table
8a0ae14 [Davies Liu] fix bug
bec07da [Davies Liu] refactor
2891628 [Davies Liu] make mutableProjection atomic
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/67e23b39
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/67e23b39
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/67e23b39
Branch: refs/heads/master
Commit: 67e23b39ac3cdee06668fa9131951278b9731e29
Parents: d728d5c
Author: Davies Liu <[email protected]>
Authored: Tue Nov 3 11:42:08 2015 +0100
Committer: Michael Armbrust <[email protected]>
Committed: Tue Nov 3 11:42:08 2015 +0100
----------------------------------------------------------------------
.../sql/catalyst/expressions/Projection.scala | 13 +-
.../expressions/aggregate/functions.scala | 154 ++++++++-----------
.../codegen/GenerateMutableProjection.scala | 28 +++-
3 files changed, 97 insertions(+), 98 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/67e23b39/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index afe52e6..a6fe730 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import
org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection,
GenerateUnsafeProjection}
-import org.apache.spark.sql.types.{DataType, Decimal, StructType, _}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.types.{DataType, StructType}
/**
* A [[Projection]] that is calculated by calling the `eval` of each of the
specified expressions.
@@ -62,6 +61,8 @@ case class InterpretedMutableProjection(expressions:
Seq[Expression]) extends Mu
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+ private[this] val buffer = new Array[Any](expressions.size)
+
expressions.foreach(_.foreach {
case n: Nondeterministic => n.setInitialValues()
case _ =>
@@ -79,7 +80,13 @@ case class InterpretedMutableProjection(expressions:
Seq[Expression]) extends Mu
override def apply(input: InternalRow): InternalRow = {
var i = 0
while (i < exprArray.length) {
- mutableRow(i) = exprArray(i).eval(input)
+ // Store the result into buffer first, to make the projection atomic
(needed by aggregation)
+ buffer(i) = exprArray(i).eval(input)
+ i += 1
+ }
+ i = 0
+ while (i < exprArray.length) {
+ mutableRow(i) = buffer(i)
i += 1
}
mutableRow
http://git-wip-us.apache.org/repos/asf/spark/blob/67e23b39/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 5d2eb7b..f2c3eca 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -57,37 +57,37 @@ case class Average(child: Expression) extends
DeclarativeAggregate {
case _ => DoubleType
}
- private val currentSum = AttributeReference("currentSum", sumDataType)()
- private val currentCount = AttributeReference("currentCount", LongType)()
+ private val sum = AttributeReference("sum", sumDataType)()
+ private val count = AttributeReference("count", LongType)()
- override val aggBufferAttributes = currentSum :: currentCount :: Nil
+ override val aggBufferAttributes = sum :: count :: Nil
override val initialValues = Seq(
- /* currentSum = */ Cast(Literal(0), sumDataType),
- /* currentCount = */ Literal(0L)
+ /* sum = */ Cast(Literal(0), sumDataType),
+ /* count = */ Literal(0L)
)
override val updateExpressions = Seq(
- /* currentSum = */
+ /* sum = */
Add(
- currentSum,
+ sum,
Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) ::
Nil)),
- /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+ /* count = */ If(IsNull(child), count, count + 1L)
)
override val mergeExpressions = Seq(
- /* currentSum = */ currentSum.left + currentSum.right,
- /* currentCount = */ currentCount.left + currentCount.right
+ /* sum = */ sum.left + sum.right,
+ /* count = */ count.left + count.right
)
- // If all input are nulls, currentCount will be 0 and we will get null after
the division.
+ // If all input are nulls, count will be 0 and we will get null after the
division.
override val evaluateExpression = child.dataType match {
case DecimalType.Fixed(p, s) =>
// increase the precision and scale to prevent precision loss
val dt = DecimalType.bounded(p + 14, s + 4)
- Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType)
+ Cast(Cast(sum, dt) / Cast(count, dt), resultType)
case _ =>
- Cast(currentSum, resultType) / Cast(currentCount, resultType)
+ Cast(sum, resultType) / Cast(count, resultType)
}
}
@@ -102,23 +102,23 @@ case class Count(child: Expression) extends
DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val currentCount = AttributeReference("currentCount", LongType)()
+ private val count = AttributeReference("count", LongType)()
- override val aggBufferAttributes = currentCount :: Nil
+ override val aggBufferAttributes = count :: Nil
override val initialValues = Seq(
- /* currentCount = */ Literal(0L)
+ /* count = */ Literal(0L)
)
override val updateExpressions = Seq(
- /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+ /* count = */ If(IsNull(child), count, count + 1L)
)
override val mergeExpressions = Seq(
- /* currentCount = */ currentCount.left + currentCount.right
+ /* count = */ count.left + count.right
)
- override val evaluateExpression = Cast(currentCount, LongType)
+ override val evaluateExpression = Cast(count, LongType)
}
/**
@@ -372,101 +372,77 @@ abstract class StddevAgg(child: Expression) extends
DeclarativeAggregate {
private val resultType = DoubleType
- private val preCount = AttributeReference("preCount", resultType)()
- private val currentCount = AttributeReference("currentCount", resultType)()
- private val preAvg = AttributeReference("preAvg", resultType)()
- private val currentAvg = AttributeReference("currentAvg", resultType)()
- private val currentMk = AttributeReference("currentMk", resultType)()
+ private val count = AttributeReference("count", resultType)()
+ private val avg = AttributeReference("avg", resultType)()
+ private val mk = AttributeReference("mk", resultType)()
- override val aggBufferAttributes = preCount :: currentCount :: preAvg ::
- currentAvg :: currentMk :: Nil
+ override val aggBufferAttributes = count :: avg :: mk :: Nil
override val initialValues = Seq(
- /* preCount = */ Cast(Literal(0), resultType),
- /* currentCount = */ Cast(Literal(0), resultType),
- /* preAvg = */ Cast(Literal(0), resultType),
- /* currentAvg = */ Cast(Literal(0), resultType),
- /* currentMk = */ Cast(Literal(0), resultType)
+ /* count = */ Cast(Literal(0), resultType),
+ /* avg = */ Cast(Literal(0), resultType),
+ /* mk = */ Cast(Literal(0), resultType)
)
override val updateExpressions = {
+ val value = Cast(child, resultType)
+ val newCount = count + Cast(Literal(1), resultType)
// update average
// avg = avg + (value - avg)/count
- def avgAdd: Expression = {
- currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount)
- }
+ val newAvg = avg + (value - avg) / newCount
// update sum of square of difference from mean
// Mk = Mk + (value - preAvg) * (value - updatedAvg)
- def mkAdd: Expression = {
- val delta1 = Cast(child, resultType) - preAvg
- val delta2 = Cast(child, resultType) - currentAvg
- currentMk + (delta1 * delta2)
- }
+ val newMk = mk + (value - avg) * (value - newAvg)
Seq(
- /* preCount = */ If(IsNull(child), preCount, currentCount),
- /* currentCount = */ If(IsNull(child), currentCount,
- Add(currentCount, Cast(Literal(1), resultType))),
- /* preAvg = */ If(IsNull(child), preAvg, currentAvg),
- /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd),
- /* currentMk = */ If(IsNull(child), currentMk, mkAdd)
+ /* count = */ If(IsNull(child), count, newCount),
+ /* avg = */ If(IsNull(child), avg, newAvg),
+ /* mk = */ If(IsNull(child), mk, newMk)
)
}
override val mergeExpressions = {
// count merge
- def countMerge: Expression = {
- currentCount.left + currentCount.right
- }
+ val newCount = count.left + count.right
// average merge
- def avgMerge: Expression = {
- ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right))
/
- (preCount + currentCount.right)
- }
+ val newAvg = ((avg.left * count.left) + (avg.right * count.right)) /
newCount
// update sum of square differences
- def mkMerge: Expression = {
- val avgDelta = currentAvg.right - preAvg
- val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) /
- (preCount + currentCount.right)
-
- currentMk.left + currentMk.right + mkDelta
+ val newMk = {
+ val avgDelta = avg.right - avg.left
+ val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) /
newCount
+ mk.left + mk.right + mkDelta
}
Seq(
- /* preCount = */ If(IsNull(currentCount.left),
- Cast(Literal(0), resultType), currentCount.left),
- /* currentCount = */ If(IsNull(currentCount.left), currentCount.right,
- If(IsNull(currentCount.right), currentCount.left,
countMerge)),
- /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType),
currentAvg.left),
- /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right,
- If(IsNull(currentAvg.right), currentAvg.left,
avgMerge)),
- /* currentMk = */ If(IsNull(currentMk.left), currentMk.right,
- If(IsNull(currentMk.right), currentMk.left, mkMerge))
+ /* count = */ If(IsNull(count.left), count.right,
+ If(IsNull(count.right), count.left, newCount)),
+ /* avg = */ If(IsNull(avg.left), avg.right,
+ If(IsNull(avg.right), avg.left, newAvg)),
+ /* mk = */ If(IsNull(mk.left), mk.right,
+ If(IsNull(mk.right), mk.left, newMk))
)
}
override val evaluateExpression = {
- // when currentCount == 0, return null
- // when currentCount == 1, return 0
- // when currentCount >1
- // stddev_samp = sqrt (currentMk/(currentCount -1))
- // stddev_pop = sqrt (currentMk/currentCount)
- val varCol = {
+ // when count == 0, return null
+ // when count == 1, return 0
+ // when count >1
+ // stddev_samp = sqrt (mk/(count -1))
+ // stddev_pop = sqrt (mk/count)
+ val varCol =
if (isSample) {
- currentMk / Cast((currentCount - Cast(Literal(1), resultType)),
resultType)
- }
- else {
- currentMk / currentCount
+ mk / Cast((count - Cast(Literal(1), resultType)), resultType)
+ } else {
+ mk / count
}
- }
- If(EqualTo(currentCount, Cast(Literal(0), resultType)),
Cast(Literal(null), resultType),
- If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0),
resultType),
+ If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null),
resultType),
+ If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0),
resultType),
Cast(Sqrt(varCol), resultType)))
}
}
@@ -499,30 +475,30 @@ case class Sum(child: Expression) extends
DeclarativeAggregate {
private val sumDataType = resultType
- private val currentSum = AttributeReference("currentSum", sumDataType)()
+ private val sum = AttributeReference("sum", sumDataType)()
private val zero = Cast(Literal(0), sumDataType)
- override val aggBufferAttributes = currentSum :: Nil
+ override val aggBufferAttributes = sum :: Nil
override val initialValues = Seq(
- /* currentSum = */ Literal.create(null, sumDataType)
+ /* sum = */ Literal.create(null, sumDataType)
)
override val updateExpressions = Seq(
- /* currentSum = */
- Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child,
sumDataType)), currentSum))
+ /* sum = */
+ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
)
override val mergeExpressions = {
- val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right,
sumDataType))
+ val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType))
Seq(
- /* currentSum = */
- Coalesce(Seq(add, currentSum.left))
+ /* sum = */
+ Coalesce(Seq(add, sum.left))
)
}
- override val evaluateExpression = Cast(currentSum, resultType)
+ override val evaluateExpression = Cast(sum, resultType)
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/67e23b39/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index e8ee647..4b66069 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -44,28 +44,42 @@ object GenerateMutableProjection extends
CodeGenerator[Seq[Expression], () => Mu
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
+ val isNull = s"isNull_$i"
+ val value = s"value_$i"
+ ctx.addMutableState("boolean", isNull, s"this.$isNull = true;")
+ ctx.addMutableState(ctx.javaType(e.dataType), value,
+ s"this.$value = ${ctx.defaultValue(e.dataType)};")
+ s"""
+ ${evaluationCode.code}
+ this.$isNull = ${evaluationCode.isNull};
+ this.$value = ${evaluationCode.value};
+ """
+ }
+ val updates = expressions.zipWithIndex.map {
+ case (NoOp, _) => ""
+ case (e, i) =>
if (e.dataType.isInstanceOf[DecimalType]) {
// Can't call setNullAt on DecimalType, because we need to keep the
offset
s"""
- ${evaluationCode.code}
- if (${evaluationCode.isNull}) {
+ if (this.isNull_$i) {
${ctx.setColumn("mutableRow", e.dataType, i, null)};
} else {
- ${ctx.setColumn("mutableRow", e.dataType, i,
evaluationCode.value)};
+ ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
}
"""
} else {
s"""
- ${evaluationCode.code}
- if (${evaluationCode.isNull}) {
+ if (this.isNull_$i) {
mutableRow.setNullAt($i);
} else {
- ${ctx.setColumn("mutableRow", e.dataType, i,
evaluationCode.value)};
+ ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
}
"""
}
}
+
val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
+ val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)
val code = s"""
public Object generate($exprType[] expr) {
@@ -98,6 +112,8 @@ object GenerateMutableProjection extends
CodeGenerator[Seq[Expression], () => Mu
public Object apply(Object _i) {
InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
$allProjections
+ // copy all the results into MutableRow
+ $allUpdates
return mutableRow;
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]