http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala new file mode 100644 index 0000000..ce1cbdc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala @@ -0,0 +1,749 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.types.NullType + +import scala.collection.mutable.ArrayBuffer + +/** + * An iterator used to evaluate aggregate functions. It assumes that input rows + * are already grouped by values of `groupingExpressions`. 
+ */ +private[sql] abstract class SortAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends Iterator[InternalRow] { + + /////////////////////////////////////////////////////////////////////////// + // Static fields for this iterator + /////////////////////////////////////////////////////////////////////////// + + protected val aggregateFunctions: Array[AggregateFunction2] = { + var bufferOffset = initialBufferOffset + val functions = new Array[AggregateFunction2](aggregateExpressions.length) + var i = 0 + while (i < aggregateExpressions.length) { + val func = aggregateExpressions(i).aggregateFunction + val funcWithBoundReferences = aggregateExpressions(i).mode match { + case Partial | Complete if !func.isInstanceOf[AlgebraicAggregate] => + // We need to create BoundReferences if the function is not an + // AlgebraicAggregate (it does not support code-gen) and the mode of + // this function is Partial or Complete because we will call eval of this + // function's children in the update method of this aggregate function. + // Those eval calls require BoundReferences to work. + BindReferences.bindReference(func, inputAttributes) + case _ => func + } + // Set bufferOffset for this function. It is important that setting bufferOffset + // happens after all potential bindReference operations because bindReference + // will create a new instance of the function. + funcWithBoundReferences.bufferOffset = bufferOffset + bufferOffset += funcWithBoundReferences.bufferSchema.length + functions(i) = funcWithBoundReferences + i += 1 + } + functions + } + + // All non-algebraic aggregate functions. 
+ protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + aggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // Positions of those non-algebraic aggregate functions in aggregateFunctions. + // For example, we have func1, func2, func3, func4 in aggregateFunctions, and + // func2 and func3 are non-algebraic aggregate functions. + // nonAlgebraicAggregateFunctionPositions will be [1, 2]. + protected val nonAlgebraicAggregateFunctionPositions: Array[Int] = { + val positions = new ArrayBuffer[Int]() + var i = 0 + while (i < aggregateFunctions.length) { + aggregateFunctions(i) match { + case agg: AlgebraicAggregate => + case _ => positions += i + } + i += 1 + } + positions.toArray + } + + // This is used to project expressions for the grouping expressions. + protected val groupGenerator = + newMutableProjection(groupingExpressions, inputAttributes)() + + // The underlying buffer shared by all aggregate functions. + protected val buffer: MutableRow = { + // The number of elements of the underlying buffer of this operator. + // All aggregate functions are sharing this underlying buffer and they find their + // buffer values through bufferOffset. + var size = initialBufferOffset + var i = 0 + while (i < aggregateFunctions.length) { + size += aggregateFunctions(i).bufferSchema.length + i += 1 + } + new GenericMutableRow(size) + } + + protected val joinedRow = new JoinedRow4 + + protected val placeholderExpressions = Seq.fill(initialBufferOffset)(NoOp) + + // This projection is used to initialize buffer values for all AlgebraicAggregates. 
+ protected val algebraicInitialProjection = { + val initExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.initialValues + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + newMutableProjection(initExpressions, Nil)().target(buffer) + } + + /////////////////////////////////////////////////////////////////////////// + // Mutable states + /////////////////////////////////////////////////////////////////////////// + + // The partition key of the current partition. + protected var currentGroupingKey: InternalRow = _ + // The partition key of next partition. + protected var nextGroupingKey: InternalRow = _ + // The first row of next partition. + protected var firstRowInNextGroup: InternalRow = _ + // Indicates if we has new group of rows to process. + protected var hasNewGroup: Boolean = true + + /////////////////////////////////////////////////////////////////////////// + // Private methods + /////////////////////////////////////////////////////////////////////////// + + /** Initializes buffer values for all aggregate functions. */ + protected def initializeBuffer(): Unit = { + algebraicInitialProjection(EmptyRow) + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + nonAlgebraicAggregateFunctions(i).initialize(buffer) + i += 1 + } + } + + protected def initialize(): Unit = { + if (inputIter.hasNext) { + initializeBuffer() + val currentRow = inputIter.next().copy() + // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, + // we are making a copy at here. + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + // This iter is an empty one. + hasNewGroup = false + } + } + + /** Processes rows in the current group. It will stop when it find a new group. 
*/ + private def processCurrentGroup(): Unit = { + currentGroupingKey = nextGroupingKey + // Now, we will start to find all rows belonging to this group. + // We create a variable to track if we see the next group. + var findNextPartition = false + // firstRowInNextGroup is the first row of this group. We first process it. + processRow(firstRowInNextGroup) + // The search will stop when we see the next group or there is no + // input row left in the iter. + while (inputIter.hasNext && !findNextPartition) { + val currentRow = inputIter.next() + // Get the grouping key based on the grouping expressions. + // For the below compare method, we do not need to make a copy of groupingKey. + val groupingKey = groupGenerator(currentRow) + // Check if the current row belongs the current input row. + currentGroupingKey.equals(groupingKey) + + if (currentGroupingKey == groupingKey) { + processRow(currentRow) + } else { + // We find a new group. + findNextPartition = true + nextGroupingKey = groupingKey.copy() + firstRowInNextGroup = currentRow.copy() + } + } + // We have not seen a new group. It means that there is no new row in the input + // iter. The current group is the last group of the iter. + if (!findNextPartition) { + hasNewGroup = false + } + } + + /////////////////////////////////////////////////////////////////////////// + // Public methods + /////////////////////////////////////////////////////////////////////////// + + override final def hasNext: Boolean = hasNewGroup + + override final def next(): InternalRow = { + if (hasNext) { + // Process the current group. + processCurrentGroup() + // Generate output row for the current group. + val outputRow = generateOutput() + // Initilize buffer values for the next group. 
+ initializeBuffer() + + outputRow + } else { + // no more result + throw new NoSuchElementException + } + } + + /////////////////////////////////////////////////////////////////////////// + // Methods that need to be implemented + /////////////////////////////////////////////////////////////////////////// + + protected def initialBufferOffset: Int + + protected def processRow(row: InternalRow): Unit + + protected def generateOutput(): InternalRow + + /////////////////////////////////////////////////////////////////////////// + // Initialize this iterator + /////////////////////////////////////////////////////////////////////////// + + initialize() +} + +/** + * An iterator only used to group input rows according to values of `groupingExpressions`. + * It assumes that input rows are already grouped by values of `groupingExpressions`. + */ +class GroupingIterator( + groupingExpressions: Seq[NamedExpression], + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends SortAggregationIterator( + groupingExpressions, + Nil, + newMutableProjection, + inputAttributes, + inputIter) { + + private val resultProjection = + newMutableProjection(resultExpressions, groupingExpressions.map(_.toAttribute))() + + override protected def initialBufferOffset: Int = 0 + + override protected def processRow(row: InternalRow): Unit = { + // Since we only do grouping, there is nothing to do at here. + } + + override protected def generateOutput(): InternalRow = { + resultProjection(currentGroupingKey) + } +} + +/** + * An iterator used to do partial aggregations (for those aggregate functions with mode Partial). + * It assumes that input rows are already grouped by values of `groupingExpressions`. 
+ * The format of its output rows is: + * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| + */ +class PartialSortAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends SortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) { + + // This projection is used to update buffer values for all AlgebraicAggregates. + private val algebraicUpdateProjection = { + val bufferSchema = aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } + val updateExpressions = aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) + } + + override protected def initialBufferOffset: Int = 0 + + override protected def processRow(row: InternalRow): Unit = { + // Process all algebraic aggregate functions. + algebraicUpdateProjection(joinedRow(buffer, row)) + // Process all non-algebraic aggregate functions. + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + nonAlgebraicAggregateFunctions(i).update(buffer, row) + i += 1 + } + } + + override protected def generateOutput(): InternalRow = { + // We just output the grouping expressions and the underlying buffer. + joinedRow(currentGroupingKey, buffer).copy() + } +} + +/** + * An iterator used to do partial merge aggregations (for those aggregate functions with mode + * PartialMerge). It assumes that input rows are already grouped by values of + * `groupingExpressions`. 
+ * The format of its input rows is: + * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| + * + * The format of its internal buffer is: + * |placeholder1|...|placeholderN|aggregationBuffer1|...|aggregationBufferN| + * Every placeholder is for a grouping expression. + * The actual buffers are stored after placeholderN. + * The reason that we have placeholders at here is to make our underlying buffer have the same + * length with a input row. + * + * The format of its output rows is: + * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| + */ +class PartialMergeSortAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends SortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) { + + private val placeholderAttribtues = + Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) + + // This projection is used to merge buffer values for all AlgebraicAggregates. 
+ private val algebraicMergeProjection = { + val bufferSchemata = + placeholderAttribtues ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ placeholderAttribtues ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + + newMutableProjection(mergeExpressions, bufferSchemata)() + } + + // This projection is used to extract aggregation buffers from the underlying buffer. + // We need it because the underlying buffer has placeholders at its beginning. + private val extractsBufferValues = { + val expressions = aggregateFunctions.flatMap { + case agg => agg.bufferAttributes + } + + newMutableProjection(expressions, inputAttributes)() + } + + override protected def initialBufferOffset: Int = groupingExpressions.length + + override protected def processRow(row: InternalRow): Unit = { + // Process all algebraic aggregate functions. + algebraicMergeProjection.target(buffer)(joinedRow(buffer, row)) + // Process all non-algebraic aggregate functions. + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + nonAlgebraicAggregateFunctions(i).merge(buffer, row) + i += 1 + } + } + + override protected def generateOutput(): InternalRow = { + // We output grouping expressions and aggregation buffers. + joinedRow(currentGroupingKey, extractsBufferValues(buffer)) + } +} + +/** + * An iterator used to do final aggregations (for those aggregate functions with mode + * Final). It assumes that input rows are already grouped by values of + * `groupingExpressions`. 
+ * The format of its input rows is: + * |groupingExpr1|...|groupingExprN|aggregationBuffer1|...|aggregationBufferN| + * + * The format of its internal buffer is: + * |placeholder1|...|placeholder N|aggregationBuffer1|...|aggregationBufferN| + * Every placeholder is for a grouping expression. + * The actual buffers are stored after placeholderN. + * The reason that we have placeholders at here is to make our underlying buffer have the same + * length with a input row. + * + * The format of its output rows is represented by the schema of `resultExpressions`. + */ +class FinalSortAggregationIterator( + groupingExpressions: Seq[NamedExpression], + aggregateExpressions: Seq[AggregateExpression2], + aggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends SortAggregationIterator( + groupingExpressions, + aggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) { + + // The result of aggregate functions. + private val aggregateResult: MutableRow = new GenericMutableRow(aggregateAttributes.length) + + // The projection used to generate the output rows of this operator. + // This is only used when we are generating final results of aggregate functions. + private val resultProjection = + newMutableProjection( + resultExpressions, groupingExpressions.map(_.toAttribute) ++ aggregateAttributes)() + + private val offsetAttributes = + Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) + + // This projection is used to merge buffer values for all AlgebraicAggregates. 
+ private val algebraicMergeProjection = { + val bufferSchemata = + offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + + newMutableProjection(mergeExpressions, bufferSchemata)() + } + + // This projection is used to evaluate all AlgebraicAggregates. + private val algebraicEvalProjection = { + val bufferSchemata = + offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val evalExpressions = aggregateFunctions.map { + case ae: AlgebraicAggregate => ae.evaluateExpression + case agg: AggregateFunction2 => NoOp + } + + newMutableProjection(evalExpressions, bufferSchemata)() + } + + override protected def initialBufferOffset: Int = groupingExpressions.length + + override def initialize(): Unit = { + if (inputIter.hasNext) { + initializeBuffer() + val currentRow = inputIter.next().copy() + // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, + // we are making a copy at here. + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + if (groupingExpressions.isEmpty) { + // If there is no grouping expression, we need to generate a single row as the output. 
+ initializeBuffer() + // Right now, the buffer only contains initial buffer values. Because + // merging two buffers with initial values will generate a row that + // still store initial values. We set the currentRow as the copy of the current buffer. + val currentRow = buffer.copy() + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + // This iter is an empty one. + hasNewGroup = false + } + } + } + + override protected def processRow(row: InternalRow): Unit = { + // Process all algebraic aggregate functions. + algebraicMergeProjection.target(buffer)(joinedRow(buffer, row)) + // Process all non-algebraic aggregate functions. + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + nonAlgebraicAggregateFunctions(i).merge(buffer, row) + i += 1 + } + } + + override protected def generateOutput(): InternalRow = { + // Generate results for all algebraic aggregate functions. + algebraicEvalProjection.target(aggregateResult)(buffer) + // Generate results for all non-algebraic aggregate functions. + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + aggregateResult.update( + nonAlgebraicAggregateFunctionPositions(i), + nonAlgebraicAggregateFunctions(i).eval(buffer)) + i += 1 + } + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } +} + +/** + * An iterator used to do both final aggregations (for those aggregate functions with mode + * Final) and complete aggregations (for those aggregate functions with mode Complete). + * It assumes that input rows are already grouped by values of `groupingExpressions`. + * The format of its input rows is: + * |groupingExpr1|...|groupingExprN|col1|...|colM|aggregationBuffer1|...|aggregationBufferN| + * col1 to colM are columns used by aggregate functions with Complete mode. + * aggregationBuffer1 to aggregationBufferN are buffers used by aggregate functions with + * Final mode. 
+ * + * The format of its internal buffer is: + * |placeholder1|...|placeholder(N+M)|aggregationBuffer1|...|aggregationBuffer(N+M)| + * The first N placeholders represent slots of grouping expressions. + * Then, next M placeholders represent slots of col1 to colM. + * For aggregation buffers, first N aggregation buffers are used by N aggregate functions with + * mode Final. Then, the last M aggregation buffers are used by M aggregate functions with mode + * Complete. The reason that we have placeholders at here is to make our underlying buffer + * have the same length with a input row. + * + * The format of its output rows is represented by the schema of `resultExpressions`. + */ +class FinalAndCompleteSortAggregationIterator( + override protected val initialBufferOffset: Int, + groupingExpressions: Seq[NamedExpression], + finalAggregateExpressions: Seq[AggregateExpression2], + finalAggregateAttributes: Seq[Attribute], + completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + inputAttributes: Seq[Attribute], + inputIter: Iterator[InternalRow]) + extends SortAggregationIterator( + groupingExpressions, + // TODO: document the ordering + finalAggregateExpressions ++ completeAggregateExpressions, + newMutableProjection, + inputAttributes, + inputIter) { + + // The result of aggregate functions. + private val aggregateResult: MutableRow = + new GenericMutableRow(completeAggregateAttributes.length + finalAggregateAttributes.length) + + // The projection used to generate the output rows of this operator. + // This is only used when we are generating final results of aggregate functions. 
+ private val resultProjection = { + val inputSchema = + groupingExpressions.map(_.toAttribute) ++ + finalAggregateAttributes ++ + completeAggregateAttributes + newMutableProjection(resultExpressions, inputSchema)() + } + + private val offsetAttributes = + Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)()) + + // All aggregate functions with mode Final. + private val finalAggregateFunctions: Array[AggregateFunction2] = { + val functions = new Array[AggregateFunction2](finalAggregateExpressions.length) + var i = 0 + while (i < finalAggregateExpressions.length) { + functions(i) = aggregateFunctions(i) + i += 1 + } + functions + } + + // All non-algebraic aggregate functions with mode Final. + private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + finalAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // All aggregate functions with mode Complete. + private val completeAggregateFunctions: Array[AggregateFunction2] = { + val functions = new Array[AggregateFunction2](completeAggregateExpressions.length) + var i = 0 + while (i < completeAggregateExpressions.length) { + functions(i) = aggregateFunctions(finalAggregateFunctions.length + i) + i += 1 + } + functions + } + + // All non-algebraic aggregate functions with mode Complete. + private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = { + completeAggregateFunctions.collect { + case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func + }.toArray + } + + // This projection is used to merge buffer values for all AlgebraicAggregates with mode + // Final. 
+ private val finalAlgebraicMergeProjection = { + val numCompleteOffsetAttributes = + completeAggregateFunctions.map(_.bufferAttributes.length).sum + val completeOffsetAttributes = + Seq.fill(numCompleteOffsetAttributes)(AttributeReference("placeholder", NullType)()) + val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp) + + val bufferSchemata = + offsetAttributes ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } ++ completeOffsetAttributes + val mergeExpressions = + placeholderExpressions ++ finalAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } ++ completeOffsetExpressions + + newMutableProjection(mergeExpressions, bufferSchemata)() + } + + // This projection is used to update buffer values for all AlgebraicAggregates with mode + // Complete. 
+ private val completeAlgebraicUpdateProjection = { + val numFinalOffsetAttributes = finalAggregateFunctions.map(_.bufferAttributes.length).sum + val finalOffsetAttributes = + Seq.fill(numFinalOffsetAttributes)(AttributeReference("placeholder", NullType)()) + val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp) + + val bufferSchema = + offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } + val updateExpressions = + placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp) + } + newMutableProjection(updateExpressions, bufferSchema ++ inputAttributes)().target(buffer) + } + + // This projection is used to evaluate all AlgebraicAggregates. + private val algebraicEvalProjection = { + val bufferSchemata = + offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.bufferAttributes + case agg: AggregateFunction2 => agg.bufferAttributes + } ++ offsetAttributes ++ aggregateFunctions.flatMap { + case ae: AlgebraicAggregate => ae.cloneBufferAttributes + case agg: AggregateFunction2 => agg.cloneBufferAttributes + } + val evalExpressions = aggregateFunctions.map { + case ae: AlgebraicAggregate => ae.evaluateExpression + case agg: AggregateFunction2 => NoOp + } + + newMutableProjection(evalExpressions, bufferSchemata)() + } + + override def initialize(): Unit = { + if (inputIter.hasNext) { + initializeBuffer() + val currentRow = inputIter.next().copy() + // partitionGenerator is a mutable projection. Since we need to track nextGroupingKey, + // we are making a copy at here. 
+ nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + if (groupingExpressions.isEmpty) { + // If there is no grouping expression, we need to generate a single row as the output. + initializeBuffer() + // Right now, the buffer only contains initial buffer values. Because + // merging two buffers with initial values will generate a row that + // still store initial values. We set the currentRow as the copy of the current buffer. + val currentRow = buffer.copy() + nextGroupingKey = groupGenerator(currentRow).copy() + firstRowInNextGroup = currentRow + } else { + // This iter is an empty one. + hasNewGroup = false + } + } + } + + override protected def processRow(row: InternalRow): Unit = { + val input = joinedRow(buffer, row) + // For all aggregate functions with mode Complete, update buffers. + completeAlgebraicUpdateProjection(input) + var i = 0 + while (i < completeNonAlgebraicAggregateFunctions.length) { + completeNonAlgebraicAggregateFunctions(i).update(buffer, row) + i += 1 + } + + // For all aggregate functions with mode Final, merge buffers. + finalAlgebraicMergeProjection.target(buffer)(input) + i = 0 + while (i < finalNonAlgebraicAggregateFunctions.length) { + finalNonAlgebraicAggregateFunctions(i).merge(buffer, row) + i += 1 + } + } + + override protected def generateOutput(): InternalRow = { + // Generate results for all algebraic aggregate functions. + algebraicEvalProjection.target(aggregateResult)(buffer) + // Generate results for all non-algebraic aggregate functions. + var i = 0 + while (i < nonAlgebraicAggregateFunctions.length) { + aggregateResult.update( + nonAlgebraicAggregateFunctionPositions(i), + nonAlgebraicAggregateFunctions(i).eval(buffer)) + i += 1 + } + + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) + } +}
http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala new file mode 100644 index 0000000..1cb2771 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{StructType, MapType, ArrayType} + +/** + * Utility functions used by the query planner to convert our plan to new aggregation code path. + */ +object Utils { + // Right now, we do not support complex types in the grouping key schema. 
+ private def supportsGroupingKeySchema(aggregate: Aggregate): Boolean = { + val hasComplexTypes = aggregate.groupingExpressions.map(_.dataType).exists { + case array: ArrayType => true + case map: MapType => true + case struct: StructType => true + case _ => false + } + + !hasComplexTypes + } + + private def tryConvert(plan: LogicalPlan): Option[Aggregate] = plan match { + case p: Aggregate if supportsGroupingKeySchema(p) => + val converted = p.transformExpressionsDown { + case expressions.Average(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Average(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Count(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(child), + mode = aggregate.Complete, + isDistinct = false) + + // We do not support multiple COUNT DISTINCT columns for now. + case expressions.CountDistinct(children) if children.length == 1 => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Count(children.head), + mode = aggregate.Complete, + isDistinct = true) + + case expressions.First(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.First(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Last(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Last(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Max(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Max(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Min(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Min(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.Sum(child) => + aggregate.AggregateExpression2( + aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = false) + + case expressions.SumDistinct(child) => + aggregate.AggregateExpression2( + 
aggregateFunction = aggregate.Sum(child), + mode = aggregate.Complete, + isDistinct = true) + } + // Check if there is any expressions.AggregateExpression1 left. + // If so, we cannot convert this plan. + val hasAggregateExpression1 = converted.aggregateExpressions.exists { expr => + // For every expression, check if it contains AggregateExpression1. + expr.find { + case agg: expressions.AggregateExpression1 => true + case other => false + }.isDefined + } + + // Check if there are multiple distinct columns. + val aggregateExpressions = converted.aggregateExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg + } + }.toSet.toSeq + val functionsWithDistinct = aggregateExpressions.filter(_.isDistinct) + val hasMultipleDistinctColumnSets = + if (functionsWithDistinct.map(_.aggregateFunction.children).distinct.length > 1) { + true + } else { + false + } + + if (!hasAggregateExpression1 && !hasMultipleDistinctColumnSets) Some(converted) else None + + case other => None + } + + private def checkInvalidAggregateFunction2(aggregate: Aggregate): Unit = { + // If the plan cannot be converted, we will do a final round check to see if the original + // logical.Aggregate contains both AggregateExpression1 and AggregateExpression2. If so, + // we need to throw an exception. + val aggregateFunction2s = aggregate.aggregateExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression2 => agg.aggregateFunction + } + }.distinct + if (aggregateFunction2s.nonEmpty) { + // For functions implemented based on the new interface, prepare a list of function names. 
+ val invalidFunctions = { + if (aggregateFunction2s.length > 1) { + s"${aggregateFunction2s.tail.map(_.nodeName).mkString(",")} " + + s"and ${aggregateFunction2s.head.nodeName} are" + } else { + s"${aggregateFunction2s.head.nodeName} is" + } + } + val errorMessage = + s"${invalidFunctions} implemented based on the new Aggregate Function " + + s"interface and it cannot be used with functions implemented based on " + + s"the old Aggregate Function interface." + throw new AnalysisException(errorMessage) + } + } + + def tryConvert( + plan: LogicalPlan, + useNewAggregation: Boolean, + codeGenEnabled: Boolean): Option[Aggregate] = plan match { + case p: Aggregate if useNewAggregation && codeGenEnabled => + val converted = tryConvert(p) + if (converted.isDefined) { + converted + } else { + checkInvalidAggregateFunction2(p) + None + } + case p: Aggregate => + checkInvalidAggregateFunction2(p) + None + case other => None + } + + def planAggregateWithoutDistinct( + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[AggregateExpression2], + aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + // 1. Create an Aggregate Operator for partial aggregations. + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. 
+ case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + val partialAggregateExpressions = aggregateExpressions.map { + case AggregateExpression2(aggregateFunction, mode, isDistinct) => + AggregateExpression2(aggregateFunction, Partial, isDistinct) + } + val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => + agg.aggregateFunction.bufferAttributes + } + val partialAggregate = + Aggregate2Sort( + None: Option[Seq[Expression]], + namedGroupingExpressions.map(_._2), + partialAggregateExpressions, + partialAggregateAttributes, + namedGroupingAttributes ++ partialAggregateAttributes, + child) + + // 2. Create an Aggregate Operator for final aggregations. + val finalAggregateExpressions = aggregateExpressions.map { + case AggregateExpression2(aggregateFunction, mode, isDistinct) => + AggregateExpression2(aggregateFunction, Final, isDistinct) + } + val finalAggregateAttributes = + finalAggregateExpressions.map { + expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + } + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. 
+ groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + val finalAggregate = Aggregate2Sort( + Some(namedGroupingAttributes), + namedGroupingAttributes, + finalAggregateExpressions, + finalAggregateAttributes, + rewrittenResultExpressions, + partialAggregate) + + finalAggregate :: Nil + } + + def planAggregateWithOneDistinct( + groupingExpressions: Seq[Expression], + functionsWithDistinct: Seq[AggregateExpression2], + functionsWithoutDistinct: Seq[AggregateExpression2], + aggregateFunctionMap: Map[(AggregateFunction2, Boolean), Attribute], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + + // 1. Create an Aggregate Operator for partial aggregations. + // The grouping expressions are original groupingExpressions and + // distinct columns. For example, for avg(distinct value) ... group by key + // the grouping expressions of this Aggregate Operator will be [key, value]. + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + + // It is safe to call head at here since functionsWithDistinct has at least one + // AggregateExpression2. 
+ val distinctColumnExpressions = + functionsWithDistinct.head.aggregateFunction.children + val namedDistinctColumnExpressions = distinctColumnExpressions.map { + case ne: NamedExpression => ne -> ne + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val distinctColumnExpressionMap = namedDistinctColumnExpressions.toMap + val distinctColumnAttributes = namedDistinctColumnExpressions.map(_._2.toAttribute) + + val partialAggregateExpressions = functionsWithoutDistinct.map { + case AggregateExpression2(aggregateFunction, mode, _) => + AggregateExpression2(aggregateFunction, Partial, false) + } + val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg => + agg.aggregateFunction.bufferAttributes + } + val partialAggregate = + Aggregate2Sort( + None: Option[Seq[Expression]], + (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2), + partialAggregateExpressions, + partialAggregateAttributes, + namedGroupingAttributes ++ distinctColumnAttributes ++ partialAggregateAttributes, + child) + + // 2. Create an Aggregate Operator for partial merge aggregations. + val partialMergeAggregateExpressions = functionsWithoutDistinct.map { + case AggregateExpression2(aggregateFunction, mode, _) => + AggregateExpression2(aggregateFunction, PartialMerge, false) + } + val partialMergeAggregateAttributes = + partialMergeAggregateExpressions.map { + expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + } + val partialMergeAggregate = + Aggregate2Sort( + Some(namedGroupingAttributes), + namedGroupingAttributes ++ distinctColumnAttributes, + partialMergeAggregateExpressions, + partialMergeAggregateAttributes, + namedGroupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateAttributes, + partialAggregate) + + // 3. Create an Aggregate Operator for final and complete aggregations. 
+ val finalAggregateExpressions = functionsWithoutDistinct.map { + case AggregateExpression2(aggregateFunction, mode, _) => + AggregateExpression2(aggregateFunction, Final, false) + } + val finalAggregateAttributes = + finalAggregateExpressions.map { + expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct) + } + val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { + // Children of an AggregateFunction with DISTINCT keyword has already + // been evaluated. At here, we need to replace original children + // to AttributeReferences. + case agg @ AggregateExpression2(aggregateFunction, mode, isDistinct) => + val rewrittenAggregateFunction = aggregateFunction.transformDown { + case expr if distinctColumnExpressionMap.contains(expr) => + distinctColumnExpressionMap(expr).toAttribute + }.asInstanceOf[AggregateFunction2] + // We rewrite the aggregate function to a non-distinct aggregation because + // its input will have distinct arguments. + val rewrittenAggregateExpression = + AggregateExpression2(rewrittenAggregateFunction, Complete, false) + + val aggregateFunctionAttribute = aggregateFunctionMap(agg.aggregateFunction, isDistinct) + (rewrittenAggregateExpression -> aggregateFunctionAttribute) + }.unzip + + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transform { + case agg: AggregateExpression2 => + aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct).toAttribute + case expression => + // We do not rely on the equality check at here since attributes may + // different cosmetically. Instead, we use semanticEquals. 
+ groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + val finalAndCompleteAggregate = FinalAndCompleteAggregate2Sort( + namedGroupingAttributes ++ distinctColumnAttributes, + namedGroupingAttributes, + finalAggregateExpressions, + finalAggregateAttributes, + completeAggregateExpressions, + completeAggregateAttributes, + rewrittenResultExpressions, + partialMergeAggregate) + + finalAndCompleteAggregate :: Nil + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala new file mode 100644 index 0000000..6c49a90 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/aggregate/udaf.scala @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.expressions.aggregate + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction2 +import org.apache.spark.sql.types._ +import org.apache.spark.sql.Row + +/** + * The abstract class for implementing user-defined aggregate function. + */ +abstract class UserDefinedAggregateFunction extends Serializable { + + /** + * A [[StructType]] represents data types of input arguments of this aggregate function. + * For example, if a [[UserDefinedAggregateFunction]] expects two input arguments + * with type of [[DoubleType]] and [[LongType]], the returned [[StructType]] will look like + * + * ``` + * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType))) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * input argument. Users can choose names to identify the input arguments. + */ + def inputSchema: StructType + + /** + * A [[StructType]] represents data types of values in the aggregation buffer. + * For example, if a [[UserDefinedAggregateFunction]]'s buffer has two values + * (i.e. two intermediate values) with type of [[DoubleType]] and [[LongType]], + * the returned [[StructType]] will look like + * + * ``` + * StructType(Seq(StructField("doubleInput", DoubleType), StructField("longInput", LongType))) + * ``` + * + * The name of a field of this [[StructType]] is only used to identify the corresponding + * buffer value. Users can choose names to identify the input arguments. + */ + def bufferSchema: StructType + + /** + * The [[DataType]] of the returned value of this [[UserDefinedAggregateFunction]]. + */ + def returnDataType: DataType + + /** Indicates if this function is deterministic. 
*/ + def deterministic: Boolean + + /** + * Initializes the given aggregation buffer. Initial values set by this method should satisfy + * the condition that when merging two buffers with initial values, the new buffer should + * still store initial values. + */ + def initialize(buffer: MutableAggregationBuffer): Unit + + /** Updates the given aggregation buffer `buffer` with new input data from `input`. */ + def update(buffer: MutableAggregationBuffer, input: Row): Unit + + /** Merges two aggregation buffers and stores the updated buffer values back in `buffer1`. */ + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit + + /** + * Calculates the final result of this [[UserDefinedAggregateFunction]] based on the given + * aggregation buffer. + */ + def evaluate(buffer: Row): Any +} + +private[sql] abstract class AggregationBuffer( + toCatalystConverters: Array[Any => Any], + toScalaConverters: Array[Any => Any], + bufferOffset: Int) + extends Row { + + override def length: Int = toCatalystConverters.length + + protected val offsets: Array[Int] = { + val newOffsets = new Array[Int](length) + var i = 0 + while (i < newOffsets.length) { + newOffsets(i) = bufferOffset + i + i += 1 + } + newOffsets + } +} + +/** + * A Mutable [[Row]] representing a mutable aggregation buffer. 
+ */ +class MutableAggregationBuffer private[sql] ( + toCatalystConverters: Array[Any => Any], + toScalaConverters: Array[Any => Any], + bufferOffset: Int, + var underlyingBuffer: MutableRow) + extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { + + override def get(i: Int): Any = { + if (i >= length || i < 0) { + throw new IllegalArgumentException( + s"Could not access ${i}th value in this buffer because it only has $length values.") + } + toScalaConverters(i)(underlyingBuffer(offsets(i))) + } + + def update(i: Int, value: Any): Unit = { + if (i >= length || i < 0) { + throw new IllegalArgumentException( + s"Could not update ${i}th value in this buffer because it only has $length values.") + } + underlyingBuffer.update(offsets(i), toCatalystConverters(i)(value)) + } + + override def copy(): MutableAggregationBuffer = { + new MutableAggregationBuffer( + toCatalystConverters, + toScalaConverters, + bufferOffset, + underlyingBuffer) + } +} + +/** + * A [[Row]] representing an immutable aggregation buffer. + */ +class InputAggregationBuffer private[sql] ( + toCatalystConverters: Array[Any => Any], + toScalaConverters: Array[Any => Any], + bufferOffset: Int, + var underlyingInputBuffer: Row) + extends AggregationBuffer(toCatalystConverters, toScalaConverters, bufferOffset) { + + override def get(i: Int): Any = { + if (i >= length || i < 0) { + throw new IllegalArgumentException( + s"Could not access ${i}th value in this buffer because it only has $length values.") + } + toScalaConverters(i)(underlyingInputBuffer(offsets(i))) + } + + override def copy(): InputAggregationBuffer = { + new InputAggregationBuffer( + toCatalystConverters, + toScalaConverters, + bufferOffset, + underlyingInputBuffer) + } +} + +/** + * The internal wrapper used to hook a [[UserDefinedAggregateFunction]] `udaf` in the + * internal aggregation code path. 
+ * @param children + * @param udaf + */ +case class ScalaUDAF( + children: Seq[Expression], + udaf: UserDefinedAggregateFunction) + extends AggregateFunction2 with Logging { + + require( + children.length == udaf.inputSchema.length, + s"$udaf only accepts ${udaf.inputSchema.length} arguments, " + + s"but ${children.length} are provided.") + + override def nullable: Boolean = true + + override def dataType: DataType = udaf.returnDataType + + override def deterministic: Boolean = udaf.deterministic + + override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType) + + override val bufferSchema: StructType = udaf.bufferSchema + + override val bufferAttributes: Seq[AttributeReference] = bufferSchema.toAttributes + + override lazy val cloneBufferAttributes = bufferAttributes.map(_.newInstance()) + + val childrenSchema: StructType = { + val inputFields = children.zipWithIndex.map { + case (child, index) => + StructField(s"input$index", child.dataType, child.nullable, Metadata.empty) + } + StructType(inputFields) + } + + lazy val inputProjection = { + val inputAttributes = childrenSchema.toAttributes + log.debug( + s"Creating MutableProj: $children, inputSchema: $inputAttributes.") + try { + GenerateMutableProjection.generate(children, inputAttributes)() + } catch { + case e: Exception => + log.error("Failed to generate mutable projection, fallback to interpreted", e) + new InterpretedMutableProjection(children, inputAttributes) + } + } + + val inputToScalaConverters: Any => Any = + CatalystTypeConverters.createToScalaConverter(childrenSchema) + + val bufferValuesToCatalystConverters: Array[Any => Any] = bufferSchema.fields.map { field => + CatalystTypeConverters.createToCatalystConverter(field.dataType) + } + + val bufferValuesToScalaConverters: Array[Any => Any] = bufferSchema.fields.map { field => + CatalystTypeConverters.createToScalaConverter(field.dataType) + } + + lazy val inputAggregateBuffer: InputAggregationBuffer = + new InputAggregationBuffer( + 
bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + bufferOffset, + null) + + lazy val mutableAggregateBuffer: MutableAggregationBuffer = + new MutableAggregationBuffer( + bufferValuesToCatalystConverters, + bufferValuesToScalaConverters, + bufferOffset, + null) + + + override def initialize(buffer: MutableRow): Unit = { + mutableAggregateBuffer.underlyingBuffer = buffer + + udaf.initialize(mutableAggregateBuffer) + } + + override def update(buffer: MutableRow, input: InternalRow): Unit = { + mutableAggregateBuffer.underlyingBuffer = buffer + + udaf.update( + mutableAggregateBuffer, + inputToScalaConverters(inputProjection(input)).asInstanceOf[Row]) + } + + override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { + mutableAggregateBuffer.underlyingBuffer = buffer1 + inputAggregateBuffer.underlyingInputBuffer = buffer2 + + udaf.merge(mutableAggregateBuffer, inputAggregateBuffer) + } + + override def eval(buffer: InternalRow = null): Any = { + inputAggregateBuffer.underlyingInputBuffer = buffer + + udaf.evaluate(inputAggregateBuffer) + } + + override def toString: String = { + s"""${udaf.getClass.getSimpleName}(${children.mkString(",")})""" + } + + override def nodeName: String = udaf.getClass.getSimpleName +} http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 28159cb..bfeecbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2420,7 +2420,7 @@ object functions { * @since 1.5.0 */ def callUDF(udfName: String, cols: Column*): Column = { - UnresolvedFunction(udfName, cols.map(_.expr)) + UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } /** @@ -2449,7 
+2449,7 @@ object functions { exprs(i) = cols(i).expr i += 1 } - UnresolvedFunction(udfName, exprs) + UnresolvedFunction(udfName, exprs, isDistinct = false) } } http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index beee101..ab8dce6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -23,6 +23,7 @@ import java.sql.Timestamp import org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.errors.DialectException +import org.apache.spark.sql.execution.aggregate.Aggregate2Sort import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ @@ -204,6 +205,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { var hasGeneratedAgg = false df.queryExecution.executedPlan.foreach { case generatedAgg: GeneratedAggregate => hasGeneratedAgg = true + case newAggregate: Aggregate2Sort => hasGeneratedAgg = true case _ => } if (!hasGeneratedAgg) { @@ -285,7 +287,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Aggregate with Code generation handling all null values testCodeGen( "SELECT sum('a'), avg('a'), count(null) FROM testData", - Row(0, null, 0) :: Nil) + Row(null, null, 0) :: Nil) } finally { sqlContext.dropTempTable("testData3x") sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue) http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3dd2413..3d71deb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.SparkFunSuite import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext._ @@ -30,6 +31,20 @@ import org.apache.spark.sql.{Row, SQLConf, execution} class PlannerSuite extends SparkFunSuite { + private def testPartialAggregationPlan(query: LogicalPlan): Unit = { + val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) + val planned = + plannedOption.getOrElse( + fail(s"Could query play aggregation query $query. Is it an aggregation query?")) + val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } + + // For the new aggregation code path, there will be three aggregate operators for + // distinct aggregations. 
+ assert( + aggregations.size == 2 || aggregations.size == 3, + s"The plan of query $query does not have partial aggregations.") + } + test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head @@ -42,23 +57,18 @@ class PlannerSuite extends SparkFunSuite { test("count is partially aggregated") { val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed - val planned = HashAggregation(query).head - val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } - - assert(aggregations.size === 2) + testPartialAggregationPlan(query) } test("count distinct is partially aggregated") { val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed - val planned = HashAggregation(query) - assert(planned.nonEmpty) + testPartialAggregationPlan(query) } test("mixed aggregates are partially aggregated") { val query = testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed - val planned = HashAggregation(query) - assert(planned.nonEmpty) + testPartialAggregationPlan(query) } test("sizeInBytes estimation of limit operator for broadcast hash join optimization") { http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 31a49a3..24a758f 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -833,6 +833,7 
@@ abstract class HiveWindowFunctionQueryFileBaseSuite "windowing_adjust_rowcontainer_sz" ) + // Only run those query tests in the realWhiteList (do not try other ignored query files). override def testCases: Seq[(String, File)] = super.testCases.filter { case (name, _) => realWhiteList.contains(name) } http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala index f458567..1fe4fe9 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import java.io.File + import org.apache.spark.sql.SQLConf import org.apache.spark.sql.hive.test.TestHive @@ -159,4 +161,9 @@ class SortMergeCompatibilitySuite extends HiveCompatibilitySuite { "join_reorder4", "join_star" ) + + // Only run those query tests in the realWhiteList (do not try other ignored query files). 
+ override def testCases: Seq[(String, File)] = super.testCases.filter { + case (name, _) => realWhiteList.contains(name) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index cec7685..4cdb83c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -451,6 +451,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) with Logging { DataSinks, Scripts, HashAggregation, + Aggregation, LeftSemiJoin, HashJoin, BasicOperators, http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index f557450..8518e33 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1464,9 +1464,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr)) + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) + // Aggregate function with DISTINCT keyword. 
+ case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil) + UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) /* Literals */ case Token("TOK_NULL", Nil) => Literal.create(null, NullType) http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 4d23c70..3259b50 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -409,7 +409,7 @@ private[hive] case class HiveWindowFunction( private[hive] case class HiveGenericUDAF( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression]) extends AggregateExpression1 with HiveInspectors { type UDFType = AbstractGenericUDAFResolver @@ -441,7 +441,7 @@ private[hive] case class HiveGenericUDAF( /** It is used as a wrapper for the hive functions which uses UDAF interface */ private[hive] case class HiveUDAF( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression]) extends AggregateExpression1 with HiveInspectors { type UDFType = UDAF @@ -550,9 +550,9 @@ private[hive] case class HiveGenericUDTF( private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], - base: AggregateExpression, + base: AggregateExpression1, isUDAFBridgeRequired: Boolean = false) - extends AggregateFunction + extends AggregateFunction1 with HiveInspectors { def this() = this(null, null, null) 
http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java new file mode 100644 index 0000000..5c9d0e9 --- /dev/null +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.hive.aggregate; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class MyDoubleAvg extends UserDefinedAggregateFunction { + + private StructType _inputDataType; + + private StructType _bufferSchema; + + private DataType _returnDataType; + + public MyDoubleAvg() { + List<StructField> inputfields = new ArrayList<StructField>(); + inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputfields); + + List<StructField> bufferFields = new ArrayList<StructField>(); + bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true)); + bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true)); + _bufferSchema = DataTypes.createStructType(bufferFields); + + _returnDataType = DataTypes.DoubleType; + } + + @Override public StructType inputSchema() { + return _inputDataType; + } + + @Override public StructType bufferSchema() { + return _bufferSchema; + } + + @Override public DataType returnDataType() { + return _returnDataType; + } + + @Override public boolean deterministic() { + return true; + } + + @Override public void initialize(MutableAggregationBuffer buffer) { + buffer.update(0, null); + buffer.update(1, 0L); + } + + @Override public void update(MutableAggregationBuffer buffer, Row input) { + if (!input.isNullAt(0)) { + if (buffer.isNullAt(0)) { + buffer.update(0, input.getDouble(0)); + buffer.update(1, 1L); + } else { + Double newValue = input.getDouble(0) + buffer.getDouble(0); + buffer.update(0, newValue); 
+ buffer.update(1, buffer.getLong(1) + 1L); + } + } + } + + @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + if (!buffer2.isNullAt(0)) { + if (buffer1.isNullAt(0)) { + buffer1.update(0, buffer2.getDouble(0)); + buffer1.update(1, buffer2.getLong(1)); + } else { + Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); + buffer1.update(0, newValue); + buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1)); + } + } + } + + @Override public Object evaluate(Row buffer) { + if (buffer.isNullAt(0)) { + return null; + } else { + return buffer.getDouble(0) / buffer.getLong(1) + 100.0; + } + } +} + http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java new file mode 100644 index 0000000..1d4587a --- /dev/null +++ b/sql/hive/src/test/java/test/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.hive.aggregate; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.spark.sql.expressions.aggregate.MutableAggregationBuffer; +import org.apache.spark.sql.expressions.aggregate.UserDefinedAggregateFunction; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.Row; + +public class MyDoubleSum extends UserDefinedAggregateFunction { + + private StructType _inputDataType; + + private StructType _bufferSchema; + + private DataType _returnDataType; + + public MyDoubleSum() { + List<StructField> inputfields = new ArrayList<StructField>(); + inputfields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); + _inputDataType = DataTypes.createStructType(inputfields); + + List<StructField> bufferFields = new ArrayList<StructField>(); + bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); + _bufferSchema = DataTypes.createStructType(bufferFields); + + _returnDataType = DataTypes.DoubleType; + } + + @Override public StructType inputSchema() { + return _inputDataType; + } + + @Override public StructType bufferSchema() { + return _bufferSchema; + } + + @Override public DataType returnDataType() { + return _returnDataType; + } + + @Override public boolean deterministic() { + return true; + } + + @Override public void initialize(MutableAggregationBuffer buffer) { + buffer.update(0, null); + } + + @Override public void update(MutableAggregationBuffer buffer, Row input) { + if (!input.isNullAt(0)) { + if (buffer.isNullAt(0)) { + buffer.update(0, input.getDouble(0)); + } else { + Double newValue = input.getDouble(0) + buffer.getDouble(0); + buffer.update(0, newValue); + } + } + } + + @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { + if (!buffer2.isNullAt(0)) { + if 
(buffer1.isNullAt(0)) { + buffer1.update(0, buffer2.getDouble(0)); + } else { + Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); + buffer1.update(0, newValue); + } + } + } + + @Override public Object evaluate(Row buffer) { + if (buffer.isNullAt(0)) { + return null; + } else { + return buffer.getDouble(0); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada new file mode 100644 index 0000000..573541a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-0-50131c0ba7b7a6b65c789a5a8497bada @@ -0,0 +1 @@ +0 http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 new file mode 100644 index 0000000..44b2a42 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-1-11eb3cc5216d5446f4165007203acc47 @@ -0,0 +1 @@ +unhex(str) - Converts hexadecimal argument to binary http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 new file mode 100644 index 0000000..97af3b8 --- /dev/null +++ 
b/sql/hive/src/test/resources/golden/udf_unhex-2-a660886085b8651852b9b77934848ae4 @@ -0,0 +1,14 @@ +unhex(str) - Converts hexadecimal argument to binary +Performs the inverse operation of HEX(str). That is, it interprets +each pair of hexadecimal digits in the argument as a number and +converts it to the byte representation of the number. The +resulting characters are returned as a binary string. + +Example: +> SELECT DECODE(UNHEX('4D7953514C'), 'UTF-8') from src limit 1; +'MySQL' + +The characters in the argument string must be legal hexadecimal +digits: '0' .. '9', 'A' .. 'F', 'a' .. 'f'. If UNHEX() encounters +any nonhexadecimal digits in the argument, it returns NULL. Also, +if there are an odd number of characters a leading 0 is appended. http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e new file mode 100644 index 0000000..b4a6f2b --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-3-4b2cf4050af229fde91ab53fd9f3af3e @@ -0,0 +1 @@ +MySQL 1267 a -4 http://git-wip-us.apache.org/repos/asf/spark/blob/c03299a1/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 new file mode 100644 index 0000000..3a67ada --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_unhex-4-7d3e094f139892ecef17de3fd63ca3c3 @@ -0,0 +1 @@ +NULL NULL NULL --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional 
commands, e-mail: [email protected]
