Repository: flink Updated Branches: refs/heads/master 46ad40588 -> b59c81bc4
[FLINK-2210] Table API support for aggregation on columns with null values Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/b59c81bc Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/b59c81bc Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/b59c81bc Branch: refs/heads/master Commit: b59c81bc41f0fc4ade5359dfdf42549a76d412fa Parents: 46ad405 Author: Shiti <ssaxena....@gmail.com> Authored: Mon Jun 15 00:29:02 2015 +0530 Committer: Aljoscha Krettek <aljoscha.kret...@gmail.com> Committed: Tue Jun 16 18:38:48 2015 +0200 ---------------------------------------------------------------------- .../table/codegen/ExpressionCodeGenerator.scala | 19 +++++++ .../api/table/expressions/aggregations.scala | 2 +- .../api/table/expressions/comparison.scala | 8 +++ .../runtime/ExpressionAggregateFunction.scala | 5 +- .../scala/table/test/AggregationsITCase.scala | 58 +++++++++++++++++++- 5 files changed, 88 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala index 49f7600..e109574 100644 --- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala +++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/codegen/ExpressionCodeGenerator.scala @@ -489,6 +489,25 @@ abstract class ExpressionCodeGenerator[R]( """.stripMargin } + case NumericIsNotNull(child) => + val childCode = generateExpression(child) + if (nullCheck) { + childCode.code + + s""" + |boolean $nullTerm = ${childCode.nullTerm}; + |if ($nullTerm) { + | 0; + |} else { + | $resultTpe $resultTerm = ${childCode.resultTerm} != null ? 1 : 0; + |} + """.stripMargin + } else { + childCode.code + + s""" + |$resultTpe $resultTerm = ${childCode.resultTerm} != null ? 1 : 0; + """.stripMargin + } + case _ => throw new ExpressionException("Could not generate code for expression " + expr) } http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala index 08e319d..a762f66 100644 --- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala +++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/aggregations.scala @@ -89,7 +89,7 @@ case class Count(child: Expression) extends Aggregation { case class Avg(child: Expression) extends Aggregation { override def toString = s"($child).avg" - override def getIntermediateFields: Seq[Expression] = Seq(child, Literal(1)) + override def getIntermediateFields: Seq[Expression] = Seq(child, NumericIsNotNull(child)) // This is just sweet. Use our own AST representation and let the code generator do // our dirty work. override def getFinalField(inputs: Seq[Expression]): Expression = http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala index 687ea7a..c60acf9 100644 --- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala +++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/expressions/comparison.scala @@ -91,3 +91,11 @@ case class IsNotNull(child: Expression) extends UnaryExpression { override def toString = s"($child).isNotNull" } + +case class NumericIsNotNull(child: Expression) extends UnaryExpression { + def typeInfo = { + BasicTypeInfo.INT_TYPE_INFO + } + + override def toString = s"($child).numericIsNotNull" +} http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala index 7e9bc0d..7d7dc1c 100644 --- a/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala +++ b/flink-staging/flink-table/src/main/scala/org/apache/flink/api/table/runtime/ExpressionAggregateFunction.scala @@ -53,7 +53,10 @@ class ExpressionAggregateFunction( var i = 0 val len = functions.length while (i < len) { - functions(i).aggregate(current.productElement(fieldPositions(i))) + val element: Any = current.productElement(fieldPositions(i)) + if (element != null){ + functions(i).aggregate(element) + } i += 1 } } http://git-wip-us.apache.org/repos/asf/flink/blob/b59c81bc/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala ---------------------------------------------------------------------- diff --git a/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala b/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala index 3b7ab8d..62ac345 100644 --- a/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala +++ b/flink-staging/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/AggregationsITCase.scala @@ -18,13 +18,16 @@ package org.apache.flink.api.scala.table.test -import org.apache.flink.api.table.ExpressionException +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} import org.apache.flink.api.scala._ import org.apache.flink.api.scala.table._ import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.api.table.typeinfo.RowTypeInfo +import org.apache.flink.api.table.{ExpressionException, Row} import org.apache.flink.core.fs.FileSystem.WriteMode -import org.apache.flink.test.util.{TestBaseUtils, MultipleProgramsTestBase} import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode +import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils} +import org.junit.Assert._ import org.junit._ import org.junit.rules.TemporaryFolder import org.junit.runner.RunWith @@ -123,5 +126,56 @@ class AggregationsITCase(mode: TestExecutionMode) extends MultipleProgramsTestBa expected = "" } + @Test + def testAggregationWithNullValues(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val dataSet = env.fromElements[(Integer, String)]( + (123, "a"), (234, "b"), (345, "c"), (0, "d")) + + implicit val rowInfo: TypeInformation[Row] = new RowTypeInfo( + Seq(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), Seq("id", "name")) + + val rowDataSet = dataSet.map { + entry => + val row = new Row(2) + val amount = if (entry._1 > 200) entry._1 else null + row.setField(0, amount) + row.setField(1, entry._2) + row + } + + val entries = rowDataSet.toTable.select('id.avg, 'id.sum, 'id.count).collect().head + val mean = entries.productElement(0).toString.toInt + val sum = entries.productElement(1).toString.toInt + val count = entries.productElement(2).toString.toInt + + assertEquals(4,count) + + val computedMean = sum / 2 + assertEquals(computedMean, mean) + } + + @Test + def testAggregationWhenAllValuesAreNull(): Unit = { + + val env = ExecutionEnvironment.getExecutionEnvironment + val dataSet = env.fromElements[(Integer, String)]( + (123, "a"), (234, "b"), (345, "c"), (0, "d")) + + implicit val rowInfo: TypeInformation[Row] = new RowTypeInfo( + Seq(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO), Seq("id", "name")) + + val rowDataSet = dataSet.map { + entry => + val row = new Row(2) + row.setField(0, null) + row.setField(1, entry._2) + row + } + + val entries = rowDataSet.toTable.select('id.max).collect().head.productElement(0) + assertEquals(entries, null) + } }