Github user clockfly commented on a diff in the pull request:
https://github.com/apache/spark/pull/14562#discussion_r75039312
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
---
@@ -219,111 +219,125 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu
}
/**
- * API for aggregation functions that are expressed in terms of imperative initialize(), update(),
- * and merge() functions which operate on Row-based aggregation buffers.
- *
- * Within these functions, code should access fields of the mutable aggregation buffer by adding the
- * bufferSchema-relative field number to `mutableAggBufferOffset` then using this new field number
- * to access the buffer Row. This is necessary because this aggregation function's buffer is
- * embedded inside of a larger shared aggregation buffer when an aggregation operator evaluates
- * multiple aggregate functions at the same time.
- *
- * We need to perform similar field number arithmetic when merging multiple intermediate
- * aggregate buffers together in `merge()` (in this case, use `inputAggBufferOffset` when accessing
- * the input buffer).
- *
- * Correct ImperativeAggregate evaluation depends on the correctness of `mutableAggBufferOffset` and
- * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes`
- * and `inputAggBufferAttributes`.
+ * API for aggregation functions that are expressed in terms of imperative doInitialize(),
+ * doUpdate(), doMerge() and doComplete() functions which operate on Row-based aggregation buffers.
*/
abstract class ImperativeAggregate extends AggregateFunction with CodegenFallback {
+ // Although `mutableBufferRow` and `inputBufferRow` are 2 mutable fields in `ImperativeAggregate`,
+ // they can only be set once, thus make `ImperativeAggregate` kind of immutable and stateless.
+
+ /**
+ * The aggregation operator keeps a large shared mutable buffer row for all aggregate functions,
+ * each aggregate function should only access a slice of this shared buffer.
+ */
+ private var mutableBufferRow: SlicedMutableRow = _
+
+ /**
+ * During partial aggregation, the input buffer row to be merged is shared among all aggregate
+ * functions, each aggregate function should only access a slice of this input buffer.
+ */
+ private var inputBufferRow: SlicedInternalRow = _
+
/**
- * The offset of this function's first buffer value in the underlying shared mutable aggregation
- * buffer.
+ * Set the offset of this function's start buffer value in the underlying shared mutable
+ * aggregation buffer.
*
* For example, we have two aggregate functions `avg(x)` and `avg(y)`, which share the same
- * aggregation buffer. In this shared buffer, the position of the first buffer value of `avg(x)`
- * will be 0 and the position of the first buffer value of `avg(y)` will be 2:
+ * aggregation buffer. In this shared buffer, the position of the start buffer value of `avg(x)`
+ * will be 0 and the position of the start buffer value of `avg(y)` will be 2:
* {{{
- * avg(x) mutableAggBufferOffset = 0
+ * avg(x) mutable buffer offset is 0
* |
* v
* +--------+--------+--------+--------+
* |  sum1  | count1 |  sum2  | count2 |
* +--------+--------+--------+--------+
*                   ^
*                   |
- * avg(y) mutableAggBufferOffset = 2
+ * avg(y) mutable buffer offset is 2
* }}}
*/
- protected val mutableAggBufferOffset: Int
-
- /**
- * Returns a copy of this ImperativeAggregate with an updated mutableAggBufferOffset.
- * This new copy's attributes may have different ids than the original.
- */
- def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate
+ final def setMutableBufferOffset(offset: Int): Unit = {
+ assert(mutableBufferRow == null)
+ mutableBufferRow = new SlicedMutableRow(offset, aggBufferAttributes.length)
+ }
/**
- * The offset of this function's start buffer value in the underlying shared input aggregation
+ * Set the offset of this function's start buffer value in the underlying shared input aggregation
* buffer. An input aggregation buffer is used when we merge two aggregation buffers together in
- * the `update()` function and is immutable (we merge an input aggregation buffer and a mutable
+ * the `merge()` function and is immutable (we merge an input aggregation buffer and a mutable
* aggregation buffer and then store the new buffer values to the mutable aggregation buffer).
*
* An input aggregation buffer may contain extra fields, such as grouping keys, at its start, so
- * mutableAggBufferOffset and inputAggBufferOffset are often different.
+ * mutable buffer offset and input buffer offset are often different.
*
* For example, say we have a grouping expression, `key`, and two aggregate functions,
- * `avg(x)` and `avg(y)`. In the shared input aggregation buffer, the position of the first
- * buffer value of `avg(x)` will be 1 and the position of the first buffer value of `avg(y)`
+ * `avg(x)` and `avg(y)`. In the shared input aggregation buffer, the position of the start
+ * buffer value of `avg(x)` will be 1 and the position of the start buffer value of `avg(y)`
* will be 3 (position 0 is used for the value of `key`):
* {{{
- * avg(x) inputAggBufferOffset = 1
+ * avg(x) input buffer offset is 1
*          |
*          v
* +--------+--------+--------+--------+--------+
* |  key   |  sum1  | count1 |  sum2  | count2 |
* +--------+--------+--------+--------+--------+
*                            ^
*                            |
- * avg(y) inputAggBufferOffset = 3
+ * avg(y) input buffer offset is 3
* }}}
*/
- protected val inputAggBufferOffset: Int
+ final def setInputBufferOffset(offset: Int): Unit = {
+ assert(inputBufferRow == null)
+ inputBufferRow = new SlicedInternalRow(offset, aggBufferAttributes.length)
+ }
- /**
- * Returns a copy of this ImperativeAggregate with an updated mutableAggBufferOffset.
- * This new copy's attributes may have different ids than the original.
- */
- def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate
+ final def initialize(mutableAggBuffer: MutableRow): Unit = {
+ doInitialize(mutableBufferRow.target(mutableAggBuffer))
+ }
+
+ final def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit = {
+ doUpdate(mutableBufferRow.target(mutableAggBuffer), inputRow)
+ }
+
+ final def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit = {
+ doMerge(mutableBufferRow.target(mutableAggBuffer), inputBufferRow.target(inputAggBuffer))
+ }
+
+ final override def eval(aggBuffer: InternalRow): Any = {
+ assert(aggBuffer.isInstanceOf[MutableRow])
+ doEval(mutableBufferRow.target(aggBuffer.asInstanceOf[MutableRow]))
+ }
+
+ final def newInstance(): ImperativeAggregate = {
+ makeCopy(mapProductIterator(_.asInstanceOf[AnyRef])).asInstanceOf[ImperativeAggregate]
--- End diff --
Where is `mapProductIterator` defined?
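
For readers following the diff, the offset-based slicing described in the new doc comments can be illustrated with a minimal, self-contained sketch. `SlicedBuffer` and `SlicedBufferDemo` are hypothetical names made up for this illustration; the PR's actual `SlicedMutableRow` is not shown in this diff, so its real interface may well differ. The sketch assumes a plain `Array[Double]` in place of a Spark row.

```scala
// Hypothetical stand-in for the PR's SlicedMutableRow (an assumption, not Spark code):
// a window of `length` fields starting at `offset` inside a larger shared buffer.
final class SlicedBuffer(offset: Int, length: Int) {
  private var underlying: Array[Double] = _

  // Point this slice at a shared buffer and return the slice for chaining.
  def target(shared: Array[Double]): SlicedBuffer = {
    require(offset + length <= shared.length, "slice exceeds shared buffer")
    underlying = shared
    this
  }

  // Read field `i` of the slice, translating to the shared buffer's index.
  def get(i: Int): Double = {
    require(i >= 0 && i < length, s"field $i is outside the slice")
    underlying(offset + i)
  }

  // Write field `i` of the slice, translating to the shared buffer's index.
  def set(i: Int, value: Double): Unit = {
    require(i >= 0 && i < length, s"field $i is outside the slice")
    underlying(offset + i) = value
  }
}

object SlicedBufferDemo {
  // An avg-style update: slice field 0 holds the sum, slice field 1 the count.
  def addValue(slice: SlicedBuffer, v: Double): Unit = {
    slice.set(0, slice.get(0) + v)
    slice.set(1, slice.get(1) + 1)
  }

  def run(): Array[Double] = {
    // Shared layout from the doc comment: [sum1, count1, sum2, count2]
    val shared = new Array[Double](4)
    val avgX = new SlicedBuffer(offset = 0, length = 2) // avg(x) starts at 0
    val avgY = new SlicedBuffer(offset = 2, length = 2) // avg(y) starts at 2

    addValue(avgX.target(shared), 10.0)
    addValue(avgX.target(shared), 20.0)
    addValue(avgY.target(shared), 5.0)
    shared
  }
}
```

In this sketch each function addresses its fields as 0 and 1 and never does offset arithmetic itself, which appears to be the motivation for wrapping the buffer in `initialize`, `update` and `merge` above.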