Repository: flink Updated Branches: refs/heads/master 417597fbf -> dd8ef550c
[FLINK-5767] [table] Add interface for user-defined aggregate functions and built-in aggregate functions. This closes #3354. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/dd8ef550 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/dd8ef550 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/dd8ef550 Branch: refs/heads/master Commit: dd8ef550cf4c590c5a84ba313c57e202d4df94f4 Parents: 417597f Author: Shaoxuan Wang <[email protected]> Authored: Sat Feb 18 17:18:38 2017 +0800 Committer: Fabian Hueske <[email protected]> Committed: Fri Feb 24 10:07:44 2017 +0100 ---------------------------------------------------------------------- .../table/functions/AggregateFunction.scala | 76 +++++ .../functions/aggfunctions/AvgAggFunction.scala | 289 +++++++++++++++++++ .../aggfunctions/CountAggFunction.scala | 57 ++++ .../functions/aggfunctions/MaxAggFunction.scala | 126 ++++++++ .../functions/aggfunctions/MinAggFunction.scala | 126 ++++++++ .../functions/aggfunctions/SumAggFunction.scala | 153 ++++++++++ .../utils/UserDefinedFunctionUtils.scala | 13 + .../aggfunctions/AggFunctionTestBase.scala | 98 +++++++ .../aggfunctions/AvgFunctionTest.scala | 186 ++++++++++++ .../aggfunctions/CountAggFunctionTest.scala | 36 +++ .../aggfunctions/MaxAggFunctionTest.scala | 188 ++++++++++++ .../aggfunctions/MinAggFunctionTest.scala | 188 ++++++++++++ .../aggfunctions/SumAggFunctionTest.scala | 147 ++++++++++ 13 files changed, 1683 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
new file mode 100644
index 0000000..e15a8c4
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/AggregateFunction.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.functions
+
+import java.util.{List => JList}
+
+/**
+ * Base class for User-Defined Aggregates.
+ *
+ * An aggregate is evaluated through an [[Accumulator]]: one is obtained from
+ * `createAccumulator`, updated once per input value by `accumulate`, optionally
+ * combined with other partial accumulators by `merge`, and finally turned into
+ * a result by `getValue`.
+ *
+ * @tparam T the type of the aggregation result
+ */
+trait AggregateFunction[T] extends UserDefinedFunction {
+  /**
+   * Creates and initializes the Accumulator for this [[AggregateFunction]].
+   *
+   * @return the accumulator with the initial value
+   */
+  def createAccumulator(): Accumulator
+
+  /**
+   * Called every time when an aggregation result should be materialized.
+   * The returned value could be either an early and incomplete result
+   * (periodically emitted as data arrive) or the final result of the
+   * aggregation.
+   *
+   * NOTE(review): the built-in functions return null when no non-null input
+   * has been accumulated; confirm whether UDAGGs are expected to do the same.
+   *
+   * @param accumulator the accumulator which contains the current
+   *                    aggregated results
+   * @return the aggregation result
+   */
+  def getValue(accumulator: Accumulator): T
+
+  /**
+   * Processes one input value and updates the provided accumulator instance
+   * in place (the method returns Unit, so the accumulator is the only output).
+   *
+   * @param accumulator the accumulator which contains the current
+   *                    aggregated results
+   * @param input       the input value (usually obtained from newly arrived data);
+   *                    the built-in functions ignore null inputs
+   */
+  def accumulate(accumulator: Accumulator, input: Any): Unit
+
+  /**
+   * Merges a list of accumulator instances into one accumulator instance.
+   *
+   * Implementations may update and return one of the given accumulators rather
+   * than a new instance (most built-in functions reuse element 0), so callers
+   * should not rely on the inputs being left unchanged. Several built-ins read
+   * element 0 unconditionally, i.e. the list is expected to be non-empty.
+   *
+   * @param accumulators the [[java.util.List]] of accumulators
+   *                     that will be merged
+   * @return the resulting accumulator
+   */
+  def merge(accumulators: JList[Accumulator]): Accumulator
+}
+
+/**
+ * Base class for aggregate Accumulator. The accumulator is used to keep the
+ * aggregated values which are needed to compute an aggregation result.
+ * All state of the function must be kept in the accumulator.
+ *
+ * TODO: We have the plan to have the accumulator and return types of
+ * functions dynamically provided by the users. This needs the refactoring
+ * of the AggregateFunction interface with the code generation. We will remove
+ * the [[Accumulator]] once codeGen for UDAGG is completed (FLINK-5813).
+ */
+trait Accumulator
http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
new file mode 100644
index 0000000..f4c0b7b
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.functions.aggfunctions
+
+import java.math.{BigDecimal, BigInteger, MathContext}
+import java.util.{List => JList}
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+
+/**
+ * Base class for built-in Integral Avg aggregate function. The average is
+ * computed by Long division of sum by count, i.e. truncated towards zero.
+ *
+ * @tparam T the type for the aggregation result
+ */
+abstract class IntegralAvgAggFunction[T] extends AggregateFunction[T] {
+
+  /** The initial accumulator for Integral Avg aggregate function (f0 = sum, f1 = count) */
+  class IntegralAvgAccumulator extends JTuple2[Long, Long] with Accumulator {
+    f0 = 0 //sum
+    f1 = 0 //count
+  }
+
+  override def createAccumulator(): Accumulator = {
+    new IntegralAvgAccumulator
+  }
+
+  override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+    if (value != null) {
+      val v = value.asInstanceOf[Number].longValue()
+      val accum = accumulator.asInstanceOf[IntegralAvgAccumulator]
+      accum.f0 += v
+      accum.f1 += 1
+    }
+  }
+
+  override def getValue(accumulator: Accumulator): T = {
+    val accum = accumulator.asInstanceOf[IntegralAvgAccumulator]
+    if (accum.f1 == 0) {
+      null.asInstanceOf[T]
+    } else {
+      resultTypeConvert(accum.f0 / accum.f1)
+    }
+  }
+
+  override def merge(accumulators: JList[Accumulator]): Accumulator = {
+    val ret = accumulators.get(0).asInstanceOf[IntegralAvgAccumulator]
+    var i: Int = 1
+    while (i < accumulators.size()) {
+      val a = accumulators.get(i).asInstanceOf[IntegralAvgAccumulator]
+      ret.f1 += a.f1
+      ret.f0 += a.f0
+      i += 1
+    }
+    ret
+  }
+
+  /**
+   * Convert the intermediate result to the expected aggregation result type
+   *
+   * @param value the intermediate result. We use a Long container to save
+   *              the intermediate result to avoid the overflow by sum operation.
+   * @return the result value with the expected aggregation result type
+   */
+  def resultTypeConvert(value: Long): T
+}
+
+/**
+ * Built-in Byte Avg aggregate function
+ */
+class ByteAvgAggFunction extends IntegralAvgAggFunction[Byte] {
+  override def resultTypeConvert(value: Long): Byte = value.toByte
+}
+
+/**
+ * Built-in Short Avg aggregate function
+ */
+class ShortAvgAggFunction extends IntegralAvgAggFunction[Short] {
+  override def resultTypeConvert(value: Long): Short = value.toShort
+}
+
+/**
+ * Built-in Int Avg aggregate function
+ */
+class IntAvgAggFunction extends IntegralAvgAggFunction[Int] {
+  override def resultTypeConvert(value: Long): Int = value.toInt
+}
+
+/**
+ * Base Class for Built-in Big Integral Avg aggregate function. The sum is kept
+ * in a BigInteger; the average is the truncating BigInteger quotient sum / count.
+ *
+ * @tparam T the type for the aggregation result
+ */
+abstract class BigIntegralAvgAggFunction[T] extends AggregateFunction[T] {
+
+  /** The initial accumulator for Big Integral Avg aggregate function (f0 = sum, f1 = count) */
+  class BigIntegralAvgAccumulator
+    extends JTuple2[BigInteger, Long] with Accumulator {
+    f0 = BigInteger.ZERO //sum
+    f1 = 0 //count
+  }
+
+  override def createAccumulator(): Accumulator = {
+    new BigIntegralAvgAccumulator
+  }
+
+  override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+    if (value != null) {
+      val v = value.asInstanceOf[Long]
+      val a = accumulator.asInstanceOf[BigIntegralAvgAccumulator]
+      a.f0 = a.f0.add(BigInteger.valueOf(v))
+      a.f1 += 1
+    }
+  }
+
+  override def getValue(accumulator: Accumulator): T = {
+    val a = accumulator.asInstanceOf[BigIntegralAvgAccumulator]
+    if (a.f1 == 0) {
+      null.asInstanceOf[T]
+    } else {
+      resultTypeConvert(a.f0.divide(BigInteger.valueOf(a.f1)))
+    }
+  }
+
+  override def merge(accumulators: JList[Accumulator]): Accumulator = {
+    val ret = accumulators.get(0).asInstanceOf[BigIntegralAvgAccumulator]
+    var i: Int = 1
+    while (i < accumulators.size()) {
+      val a = accumulators.get(i).asInstanceOf[BigIntegralAvgAccumulator]
+      ret.f1 += a.f1
+      ret.f0 = ret.f0.add(a.f0)
+      i += 1
+    }
+    ret
+  }
+
+  /**
+   * Convert the intermediate result to the expected aggregation result type
+   *
+   * @param value the intermediate result. We use a BigInteger container to
+   *              save the intermediate result to avoid the overflow by sum
+   *              operation.
+   * @return the result value with the expected aggregation result type
+   */
+  def resultTypeConvert(value: BigInteger): T
+}
+
+/**
+ * Built-in Long Avg aggregate function
+ */
+class LongAvgAggFunction extends BigIntegralAvgAggFunction[Long] {
+  override def resultTypeConvert(value: BigInteger): Long = value.longValue()
+}
+
+/**
+ * Base class for built-in Floating Avg aggregate function
+ *
+ * @tparam T the type for the aggregation result
+ */
+abstract class FloatingAvgAggFunction[T] extends AggregateFunction[T] {
+
+  /** The initial accumulator for Floating Avg aggregate function (f0 = sum, f1 = count) */
+  class FloatingAvgAccumulator extends JTuple2[Double, Long] with Accumulator {
+    f0 = 0 //sum
+    f1 = 0 //count
+  }
+
+  override def createAccumulator(): Accumulator = {
+    new FloatingAvgAccumulator
+  }
+
+  override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+    if (value != null) {
+      val v = value.asInstanceOf[Number].doubleValue()
+      val accum = accumulator.asInstanceOf[FloatingAvgAccumulator]
+      accum.f0 += v
+      accum.f1 += 1
+    }
+  }
+
+  override def getValue(accumulator: Accumulator): T = {
+    val accum = accumulator.asInstanceOf[FloatingAvgAccumulator]
+    if (accum.f1 == 0) {
+      null.asInstanceOf[T]
+    } else {
+      resultTypeConvert(accum.f0 / accum.f1)
+    }
+  }
+
+  override def merge(accumulators: JList[Accumulator]): Accumulator = {
+    val ret = accumulators.get(0).asInstanceOf[FloatingAvgAccumulator]
+    var i: Int = 1
+    while (i < accumulators.size()) {
+      val a = accumulators.get(i).asInstanceOf[FloatingAvgAccumulator]
+      ret.f1 += a.f1
+      ret.f0 += a.f0
+      i += 1
+    }
+    ret
+  }
+
+  /**
+   * Convert the intermediate result to the expected aggregation result type
+   *
+   * @param value the intermediate result. We use a Double container to save
+   *              the intermediate result to avoid the overflow by sum operation.
+   * @return the result value with the expected aggregation result type
+   */
+  def resultTypeConvert(value: Double): T
+}
+
+/**
+ * Built-in Float Avg aggregate function
+ */
+class FloatAvgAggFunction extends FloatingAvgAggFunction[Float] {
+  override def resultTypeConvert(value: Double): Float = value.toFloat
+}
+
+/**
+ * Built-in Double Avg aggregate function
+ */
+class DoubleAvgAggFunction extends FloatingAvgAggFunction[Double] {
+  override def resultTypeConvert(value: Double): Double = value
+}
+
+/**
+ * Built-in Big Decimal Avg aggregate function. The quotient sum / count is
+ * rounded to [[java.math.MathContext.DECIMAL128]] (34 digits, HALF_EVEN).
+ */
+class DecimalAvgAggFunction extends AggregateFunction[BigDecimal] {
+
+  /** The initial accumulator for Big Decimal Avg aggregate function (f0 = sum, f1 = count) */
+  class DecimalAvgAccumulator
+    extends JTuple2[BigDecimal, Long] with Accumulator {
+    f0 = BigDecimal.ZERO //sum
+    f1 = 0 //count
+  }
+
+  override def createAccumulator(): Accumulator = {
+    new DecimalAvgAccumulator
+  }
+
+  override def accumulate(accumulator: Accumulator, value: Any): Unit = {
+    if (value != null) {
+      val v = value.asInstanceOf[BigDecimal]
+      val accum = accumulator.asInstanceOf[DecimalAvgAccumulator]
+      if (accum.f1 == 0) {
+        accum.f0 = v
+      } else {
+        accum.f0 = accum.f0.add(v)
+      }
+      accum.f1 += 1
+    }
+  }
+
+  override def getValue(accumulator: Accumulator): BigDecimal = {
+    val a = accumulator.asInstanceOf[DecimalAvgAccumulator]
+    if (a.f1 == 0) {
+      null.asInstanceOf[BigDecimal]
+    } else {
+      // An unbounded divide() throws ArithmeticException for non-terminating
+      // quotients (e.g. 5 / 3), so round to a fixed decimal precision instead.
+      a.f0.divide(BigDecimal.valueOf(a.f1), MathContext.DECIMAL128)
+    }
+  }
+
+  override def merge(accumulators: JList[Accumulator]): Accumulator = {
+    val ret = accumulators.get(0).asInstanceOf[DecimalAvgAccumulator]
+    var i: Int = 1
+    while (i < accumulators.size()) {
+      val a = accumulators.get(i).asInstanceOf[DecimalAvgAccumulator]
+      ret.f0 = ret.f0.add(a.f0)
+      ret.f1 += a.f1
+      i += 1
+    }
+    ret
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala new file mode 100644 index 0000000..8b903d1 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunction.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.flink.table.functions.aggfunctions
+
+import java.util.{List => JList}
+import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1}
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+
+/**
+ * Built-in COUNT aggregate function: counts the non-null input values.
+ */
+class CountAggFunction extends AggregateFunction[Long] {
+
+  /** Accumulator whose field f0 holds the running count, starting at zero. */
+  class CountAccumulator extends JTuple1[Long] with Accumulator {
+    f0 = 0L
+  }
+
+  override def createAccumulator(): Accumulator = new CountAccumulator
+
+  override def accumulate(accumulator: Accumulator, value: Any): Unit =
+    if (value != null) {
+      val counter = accumulator.asInstanceOf[CountAccumulator]
+      counter.f0 = counter.f0 + 1
+    }
+
+  override def getValue(accumulator: Accumulator): Long =
+    accumulator.asInstanceOf[CountAccumulator].f0
+
+  /** Adds the counts of all other accumulators into the first one and returns it. */
+  override def merge(accumulators: JList[Accumulator]): Accumulator = {
+    val target = accumulators.get(0).asInstanceOf[CountAccumulator]
+    for (idx <- 1 until accumulators.size()) {
+      target.f0 = target.f0 + accumulators.get(idx).asInstanceOf[CountAccumulator].f0
+    }
+    target
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
new file mode 100644
index 0000000..20041ee
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunction.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.functions.aggfunctions
+
+import java.math.BigDecimal
+import java.util.{List => JList}
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+
+/**
+ * Base class for the built-in MAX aggregate functions. Null inputs are ignored
+ * and the result is null when no non-null value has been accumulated.
+ *
+ * @tparam T the type for the aggregation result
+ */
+abstract class MaxAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] {
+
+  /**
+   * Accumulator: f0 holds the current maximum, f1 tells whether f0 is valid,
+   * i.e. whether any non-null value has been seen yet.
+   */
+  class MaxAccumulator extends JTuple2[T, Boolean] with Accumulator {
+    f0 = 0.asInstanceOf[T] // placeholder; only read by this class once f1 is true
+    f1 = false
+  }
+
+  override def createAccumulator(): Accumulator = new MaxAccumulator
+
+  override def accumulate(accumulator: Accumulator, value: Any): Unit =
+    if (value != null) {
+      val candidate = value.asInstanceOf[T]
+      val acc = accumulator.asInstanceOf[MaxAccumulator]
+      // compare() rather than lt(): some Ordering instances override lt() with
+      // primitive (IEEE) semantics that treat NaN differently.
+      val replacesCurrent = !acc.f1 || ord.compare(acc.f0, candidate) < 0
+      if (replacesCurrent) {
+        acc.f0 = candidate
+        acc.f1 = true
+      }
+    }
+
+  override def getValue(accumulator: Accumulator): T = {
+    val acc = accumulator.asInstanceOf[MaxAccumulator]
+    if (!acc.f1) null.asInstanceOf[T] else acc.f0
+  }
+
+  /** Folds every non-empty accumulator into the first one (mutating it) and returns it. */
+  override def merge(accumulators: JList[Accumulator]): Accumulator = {
+    val target = accumulators.get(0)
+    for (idx <- 1 until accumulators.size()) {
+      val other = accumulators.get(idx).asInstanceOf[MaxAccumulator]
+      if (other.f1) {
+        accumulate(target.asInstanceOf[MaxAccumulator], other.f0)
+      }
+    }
+    target
+  }
+}
+
+/** Built-in Byte Max aggregate function */
+class ByteMaxAggFunction extends MaxAggFunction[Byte]
+
+/** Built-in Short Max aggregate function */
+class ShortMaxAggFunction extends MaxAggFunction[Short]
+
+/** Built-in Int Max aggregate function */
+class IntMaxAggFunction extends MaxAggFunction[Int]
+
+/** Built-in Long Max aggregate function */
+class LongMaxAggFunction extends MaxAggFunction[Long]
+
+/** Built-in Float Max aggregate function */
+class FloatMaxAggFunction extends MaxAggFunction[Float]
+
+/** Built-in Double Max aggregate function */
+class DoubleMaxAggFunction extends MaxAggFunction[Double]
+
+/** Built-in Boolean Max aggregate function */
+class BooleanMaxAggFunction extends MaxAggFunction[Boolean]
+
+/**
+ * Built-in Big Decimal Max aggregate function; values are compared with
+ * [[java.math.BigDecimal]]'s compareTo.
+ */
+class DecimalMaxAggFunction extends MaxAggFunction[BigDecimal] {
+
+  override def accumulate(accumulator: Accumulator, value: Any): Unit =
+    if (value != null) {
+      val candidate = value.asInstanceOf[BigDecimal]
+      val acc = accumulator.asInstanceOf[MaxAccumulator]
+      if (!acc.f1 || acc.f0.compareTo(candidate) < 0) {
+        acc.f0 = candidate
+        acc.f1 = true
+      }
+    }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
new file mode 100644
index 0000000..16461ae
--- /dev/null
+++
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunction.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.functions.aggfunctions
+
+import java.math.BigDecimal
+import java.util.{List => JList}
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+
+/**
+ * Base class for the built-in MIN aggregate functions. Null inputs are ignored
+ * and the result is null when no non-null value has been accumulated.
+ *
+ * @tparam T the type for the aggregation result
+ */
+abstract class MinAggFunction[T](implicit ord: Ordering[T]) extends AggregateFunction[T] {
+
+  /**
+   * Accumulator: f0 holds the current minimum, f1 tells whether f0 is valid,
+   * i.e. whether any non-null value has been seen yet.
+   */
+  class MinAccumulator extends JTuple2[T, Boolean] with Accumulator {
+    f0 = 0.asInstanceOf[T] // placeholder; only read by this class once f1 is true
+    f1 = false
+  }
+
+  override def createAccumulator(): Accumulator = new MinAccumulator
+
+  override def accumulate(accumulator: Accumulator, value: Any): Unit =
+    if (value != null) {
+      val candidate = value.asInstanceOf[T]
+      val acc = accumulator.asInstanceOf[MinAccumulator]
+      // compare() rather than gt(): some Ordering instances override gt() with
+      // primitive (IEEE) semantics that treat NaN differently.
+      val replacesCurrent = !acc.f1 || ord.compare(acc.f0, candidate) > 0
+      if (replacesCurrent) {
+        acc.f0 = candidate
+        acc.f1 = true
+      }
+    }
+
+  override def getValue(accumulator: Accumulator): T = {
+    val acc = accumulator.asInstanceOf[MinAccumulator]
+    if (!acc.f1) null.asInstanceOf[T] else acc.f0
+  }
+
+  /** Folds every non-empty accumulator into the first one (mutating it) and returns it. */
+  override def merge(accumulators: JList[Accumulator]): Accumulator = {
+    val target = accumulators.get(0)
+    for (idx <- 1 until accumulators.size()) {
+      val other = accumulators.get(idx).asInstanceOf[MinAccumulator]
+      if (other.f1) {
+        accumulate(target.asInstanceOf[MinAccumulator], other.f0)
+      }
+    }
+    target
+  }
+}
+
+/** Built-in Byte Min aggregate function */
+class ByteMinAggFunction extends MinAggFunction[Byte]
+
+/** Built-in Short Min aggregate function */
+class ShortMinAggFunction extends MinAggFunction[Short]
+
+/** Built-in Int Min aggregate function */
+class IntMinAggFunction extends MinAggFunction[Int]
+
+/** Built-in Long Min aggregate function */
+class LongMinAggFunction extends MinAggFunction[Long]
+
+/** Built-in Float Min aggregate function */
+class FloatMinAggFunction extends MinAggFunction[Float]
+
+/** Built-in Double Min aggregate function */
+class DoubleMinAggFunction extends MinAggFunction[Double]
+
+/** Built-in Boolean Min aggregate function */
+class BooleanMinAggFunction extends MinAggFunction[Boolean]
+
+/**
+ * Built-in Big Decimal Min aggregate function; values are compared with
+ * [[java.math.BigDecimal]]'s compareTo.
+ */
+class DecimalMinAggFunction extends MinAggFunction[BigDecimal] {
+
+  override def accumulate(accumulator: Accumulator, value: Any): Unit =
+    if (value != null) {
+      val candidate = value.asInstanceOf[BigDecimal]
+      val acc = accumulator.asInstanceOf[MinAccumulator]
+      if (!acc.f1 || acc.f0.compareTo(candidate) > 0) {
+        acc.f0 = candidate
+        acc.f1 = true
+      }
+    }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala new file mode 100644 index 0000000..b04d8c0 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunction.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.flink.table.functions.aggfunctions
+
+import java.math.BigDecimal
+import java.util.{List => JList}
+import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
+import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+
+/**
+ * Base class for the built-in SUM aggregate functions over numeric types.
+ * Null inputs are ignored; the result is null if no non-null value was summed.
+ *
+ * @tparam T the type for the aggregation result
+ */
+abstract class SumAggFunction[T: Numeric] extends AggregateFunction[T] {
+
+  // Numeric instance from the context bound; supplies zero and plus for T.
+  private val numeric = implicitly[Numeric[T]]
+
+  /** Accumulator: f0 holds the running sum, f1 tells whether any non-null input was seen. */
+  class SumAccumulator extends JTuple2[T, Boolean] with Accumulator {
+    f0 = numeric.zero
+    f1 = false
+  }
+
+  override def createAccumulator(): Accumulator = new SumAccumulator
+
+  override def accumulate(accumulator: Accumulator, value: Any): Unit =
+    if (value != null) {
+      val addend = value.asInstanceOf[T]
+      val acc = accumulator.asInstanceOf[SumAccumulator]
+      acc.f0 = numeric.plus(addend, acc.f0)
+      acc.f1 = true
+    }
+
+  override def getValue(accumulator: Accumulator): T = {
+    val acc = accumulator.asInstanceOf[SumAccumulator]
+    if (!acc.f1) null.asInstanceOf[T] else acc.f0
+  }
+
+  /** Sums all non-empty accumulators into a fresh accumulator; the inputs are not modified. */
+  override def merge(accumulators: JList[Accumulator]): Accumulator = {
+    val merged = createAccumulator().asInstanceOf[SumAccumulator]
+    for (idx <- 0 until accumulators.size()) {
+      val other = accumulators.get(idx).asInstanceOf[SumAccumulator]
+      if (other.f1) {
+        merged.f0 = numeric.plus(merged.f0, other.f0)
+        merged.f1 = true
+      }
+    }
+    merged
+  }
+}
+
+/** Built-in Byte Sum aggregate function */
+class ByteSumAggFunction extends SumAggFunction[Byte]
+
+/** Built-in Short Sum aggregate function */
+class ShortSumAggFunction extends SumAggFunction[Short]
+
+/** Built-in Int Sum aggregate function */
+class IntSumAggFunction extends SumAggFunction[Int]
+
+/** Built-in Long Sum aggregate function */
+class LongSumAggFunction extends SumAggFunction[Long]
+
+/** Built-in Float Sum aggregate function */
+class FloatSumAggFunction extends SumAggFunction[Float]
+
+/** Built-in Double Sum aggregate function */
+class DoubleSumAggFunction extends SumAggFunction[Double]
+
+/**
+ * Built-in Big Decimal Sum aggregate function. Null inputs are ignored; the
+ * result is null if no non-null value was summed.
+ */
+class DecimalSumAggFunction extends AggregateFunction[BigDecimal] {
+
+  /** Accumulator: f0 holds the running sum, f1 tells whether any non-null input was seen. */
+  class DecimalSumAccumulator extends JTuple2[BigDecimal, Boolean] with Accumulator {
+    f0 = BigDecimal.ZERO
+    f1 = false
+  }
+
+  override def createAccumulator(): Accumulator = new DecimalSumAccumulator
+
+  override def accumulate(accumulator: Accumulator, value: Any): Unit =
+    if (value != null) {
+      val addend = value.asInstanceOf[BigDecimal]
+      val acc = accumulator.asInstanceOf[DecimalSumAccumulator]
+      acc.f0 = acc.f0.add(addend)
+      acc.f1 = true
+    }
+
+  override def getValue(accumulator: Accumulator): BigDecimal = {
+    val acc = accumulator.asInstanceOf[DecimalSumAccumulator]
+    if (acc.f1) acc.f0 else null.asInstanceOf[BigDecimal]
+  }
+
+  /** Adds every non-empty accumulator into the first one (mutating it) and returns it. */
+  override def merge(accumulators: JList[Accumulator]): Accumulator = {
+    val target = accumulators.get(0).asInstanceOf[DecimalSumAccumulator]
+    for (idx <- 1 until accumulators.size()) {
+      val other = accumulators.get(idx).asInstanceOf[DecimalSumAccumulator]
+      if (other.f1) {
+        accumulate(target, other.f0)
+        target.f1 = true
+      }
+    }
+    target
+  }
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
index 16a6717b..aec4fbb 100644
---
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
@@ -119,6 +119,19 @@ object UserDefinedFunctionUtils {
   }
 
   /**
+   * Checks whether the class of `function` has a public method (declared or
+   * inherited) named `method`. Note: "Exit" in this method's name is a typo
+   * for "Exist"; the name is kept so that existing callers keep compiling.
+   *
+   * @param method   the method name to look for
+   * @param function the user-defined function to inspect
+   * @return true if at least one public method with that name exists
+   */
+  def ifMethodExitInFunction(method: String, function: UserDefinedFunction): Boolean = {
+    function.getClass.getMethods.exists(_.getName == method)
+  }
+
+  /**
    * Extracts "eval" methods and throws a [[ValidationException]] if no implementation
    * can be found.
    */
http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala
new file mode 100644
index 0000000..627b25b
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AggFunctionTestBase.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.functions.aggfunctions + +import java.math.BigDecimal +import java.util.{ArrayList => JArrayList, List => JList} +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ +import org.junit.Assert.assertEquals +import org.junit.Test + +/** + * Base class for aggregate function tests. + * + * Each input set in inputValueSets is aggregated and the result is compared with + * the element at the same position in expectedResults. + * + * @tparam T the type for the aggregation result + */ +abstract class AggFunctionTestBase[T] { + def inputValueSets: Seq[Seq[_]] + + def expectedResults: Seq[T] + + def aggregator: AggregateFunction[T] + + @Test + // test aggregate functions without partial merge + def testAggregateWithoutMerge(): Unit = { + // iterate over input sets + for ((vals, expected) <- inputValueSets.zip(expectedResults)) { + val accumulator = aggregateVals(vals) + val result = aggregator.getValue(accumulator) + validateResult(expected, result) + } + } + + @Test + // test aggregate functions with partial merge + def testAggregateWithMerge(): Unit = { + + if (ifMethodExitInFunction("merge", aggregator)) { + // iterate over input sets + for ((vals, expected) <- inputValueSets.zip(expectedResults)) { + //equally split the vals sequence into two sequences + val (firstVals, secondVals) = vals.splitAt(vals.length / 2) + + val accumulators: JList[Accumulator] = new JArrayList[Accumulator]() + accumulators.add(aggregateVals(firstVals)) + accumulators.add(aggregateVals(secondVals)) + + val accumulator = aggregator.merge(accumulators) + val result = 
aggregator.getValue(accumulator) + validateResult(expected, result) + } + + // iterate over input sets + for ((vals, expected) <- inputValueSets.zip(expectedResults)) { + //test partial merge with an empty accumulator + val accumulators: JList[Accumulator] = new JArrayList[Accumulator]() + accumulators.add(aggregateVals(vals)) + accumulators.add(aggregator.createAccumulator()) + + val accumulator = aggregator.merge(accumulators) + val result = aggregator.getValue(accumulator) + validateResult(expected, result) + } + } + } + + private def validateResult(expected: T, result: T): Unit = { + (expected, result) match { + case (e: BigDecimal, r: BigDecimal) => + // BigDecimal.equals() compares both value and scale, but only the value matters here. + assert(e.compareTo(r) == 0, s"expected: <$e> but was: <$r>") + case _ => + assertEquals(expected, result) + } + } + + /** Creates a fresh accumulator and feeds every value of vals into it. */ + private def aggregateVals(vals: Seq[_]): Accumulator = { + val accumulator = aggregator.createAccumulator() + vals.foreach(v => aggregator.accumulate(accumulator, v)) + accumulator + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AvgFunctionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AvgFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AvgFunctionTest.scala new file mode 100644 index 0000000..a388acf --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/AvgFunctionTest.scala @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.functions.aggfunctions + +import java.math.BigDecimal +import org.apache.flink.table.functions.AggregateFunction + +/** + * Test case for built-in average aggregate function + * + * Because T is an erased type parameter here, null.asInstanceOf[T] stays a real + * null, so the null entries below are passed to the function as actual nulls. + * + * @tparam T the type for the aggregation result + */ +abstract class AvgAggFunctionTestBase[T: Numeric] extends AggFunctionTestBase[T] { + + private val numeric: Numeric[T] = implicitly[Numeric[T]] + + def minVal: T + + def maxVal: T + + override def inputValueSets: Seq[Seq[T]] = Seq( + Seq( + minVal, + minVal, + null.asInstanceOf[T], + minVal, + minVal, + null.asInstanceOf[T], + minVal, + minVal, + minVal + ), + Seq( + maxVal, + maxVal, + null.asInstanceOf[T], + maxVal, + maxVal, + null.asInstanceOf[T], + maxVal, + maxVal, + maxVal + ), + Seq( + minVal, + maxVal, + null.asInstanceOf[T], + numeric.fromInt(0), + numeric.negate(maxVal), + numeric.negate(minVal), + null.asInstanceOf[T] + ), + Seq( + numeric.fromInt(1), + numeric.fromInt(2), + null.asInstanceOf[T], + numeric.fromInt(3), + numeric.fromInt(4), + numeric.fromInt(5), + null.asInstanceOf[T] + ), + Seq( + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T] + ) + ) + + override def expectedResults: Seq[T] = Seq( + minVal, + maxVal, + numeric.fromInt(0), + numeric.fromInt(3), + null.asInstanceOf[T] + ) +} + +class 
ByteAvgAggFunctionTest extends AvgAggFunctionTestBase[Byte] { + + override def minVal = (Byte.MinValue + 1).toByte + + override def maxVal = (Byte.MaxValue - 1).toByte + + override def aggregator = new ByteAvgAggFunction() +} + +class ShortAvgAggFunctionTest extends AvgAggFunctionTestBase[Short] { + + override def minVal = (Short.MinValue + 1).toShort + + override def maxVal = (Short.MaxValue - 1).toShort + + override def aggregator = new ShortAvgAggFunction() +} + +class IntAvgAggFunctionTest extends AvgAggFunctionTestBase[Int] { + + override def minVal = Int.MinValue + 1 + + override def maxVal = Int.MaxValue - 1 + + override def aggregator = new IntAvgAggFunction() +} + +class LongAvgAggFunctionTest extends AvgAggFunctionTestBase[Long] { + + override def minVal = Long.MinValue + 1 + + override def maxVal = Long.MaxValue - 1 + + override def aggregator = new LongAvgAggFunction() +} + +class FloatAvgAggFunctionTest extends AvgAggFunctionTestBase[Float] { + + override def minVal = Float.MinValue + + override def maxVal = Float.MaxValue + + override def aggregator = new FloatAvgAggFunction() +} + +/** + * Uses the Float range (widened to Double) rather than Double.MinValue/MaxValue. + * NOTE(review): presumably this keeps the sum of the repeated extreme values + * finite; confirm against DoubleAvgAggFunction's accumulator. + */ +class DoubleAvgAggFunctionTest extends AvgAggFunctionTestBase[Double] { + + override def minVal = Float.MinValue + + override def maxVal = Float.MaxValue + + override def aggregator = new DoubleAvgAggFunction() +} + +/** + * Null inputs are skipped: the first two sets average 0 / 5 = 0 and 5 / 5 = 1, + * and the all-null set is expected to yield null. + */ +class DecimalAvgAggFunctionTest extends AggFunctionTestBase[BigDecimal] { + + override def inputValueSets: Seq[Seq[_]] = Seq( + Seq( + new BigDecimal("987654321000000"), + new BigDecimal("-0.000000000012345"), + null, + new BigDecimal("0.000000000012345"), + new BigDecimal("-987654321000000"), + null, + new BigDecimal("0") + ), + Seq( + new BigDecimal("987654321000000"), + new BigDecimal("-0.000000000012345"), + null, + new BigDecimal("0.000000000012345"), + new BigDecimal("-987654321000000"), + null, + new BigDecimal("5") + ), + Seq( + null, + null, + null, + null + ) + ) + + override def expectedResults: Seq[BigDecimal] = Seq( + BigDecimal.ZERO, + 
BigDecimal.ONE, + null + ) + + override def aggregator: AggregateFunction[BigDecimal] = new DecimalAvgAggFunction() +} http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunctionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunctionTest.scala new file mode 100644 index 0000000..d5f09b2 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/CountAggFunctionTest.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.functions.aggfunctions + +import org.apache.flink.table.functions.AggregateFunction + +/** + * Test case for built-in count aggregate function. + * + * Only non-null inputs are counted: the first set has 6 non-null values, and the + * all-null set is expected to count 0 rather than yield null. + */ +class CountAggFunctionTest extends AggFunctionTestBase[Long] { + + override def inputValueSets: Seq[Seq[_]] = Seq( + Seq("a", "b", null, "c", null, "d", "e", null, "f"), + Seq(null, null, null, null, null, null) + ) + + override def expectedResults: Seq[Long] = Seq(6L, 0L) + + override def aggregator: AggregateFunction[Long] = new CountAggFunction() +} http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionTest.scala new file mode 100644 index 0000000..91cbeea --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionTest.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.functions.aggfunctions + +import java.math.BigDecimal +import org.apache.flink.table.functions.AggregateFunction + +/** + * Test case for built-in max aggregate function + * + * @tparam T the type for the aggregation result + */ +abstract class MaxAggFunctionTest[T: Numeric] extends AggFunctionTestBase[T] { + + private val numeric: Numeric[T] = implicitly[Numeric[T]] + + def minVal: T + + def maxVal: T + + override def inputValueSets: Seq[Seq[T]] = Seq( + Seq( + numeric.fromInt(1), + null.asInstanceOf[T], + maxVal, + numeric.fromInt(-99), + numeric.fromInt(3), + numeric.fromInt(56), + numeric.fromInt(0), + minVal, + numeric.fromInt(-20), + numeric.fromInt(17), + null.asInstanceOf[T] + ), + Seq( + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T] + ) + ) + + override def expectedResults: Seq[T] = Seq( + maxVal, + null.asInstanceOf[T] + ) +} + +class ByteMaxAggFunctionTest extends MaxAggFunctionTest[Byte] { + + override def minVal = (Byte.MinValue + 1).toByte + + override def maxVal = (Byte.MaxValue - 1).toByte + + override def aggregator: AggregateFunction[Byte] = new ByteMaxAggFunction() +} + +class ShortMaxAggFunctionTest extends MaxAggFunctionTest[Short] { + + override def minVal = (Short.MinValue + 1).toShort + + override def maxVal = (Short.MaxValue - 1).toShort + + override def aggregator: AggregateFunction[Short] = new ShortMaxAggFunction() +} + +class IntMaxAggFunctionTest extends MaxAggFunctionTest[Int] { + + override def minVal = Int.MinValue + 1 + + override def maxVal = Int.MaxValue - 1 + + override def aggregator: AggregateFunction[Int] = new IntMaxAggFunction() +} + +class LongMaxAggFunctionTest extends MaxAggFunctionTest[Long] { + + override def minVal = Long.MinValue + 1 + + override def maxVal = Long.MaxValue - 1 + + 
override def aggregator: AggregateFunction[Long] = new LongMaxAggFunction() +} + +class FloatMaxAggFunctionTest extends MaxAggFunctionTest[Float] { + + override def minVal = Float.MinValue / 2 + + override def maxVal = Float.MaxValue / 2 + + override def aggregator: AggregateFunction[Float] = new FloatMaxAggFunction() +} + +class DoubleMaxAggFunctionTest extends MaxAggFunctionTest[Double] { + + override def minVal = Double.MinValue / 2 + + override def maxVal = Double.MaxValue / 2 + + override def aggregator: AggregateFunction[Double] = new DoubleMaxAggFunction() +} + +class BooleanMaxAggFunctionTest extends AggFunctionTestBase[Boolean] { + + // Real nulls are required here: with the concrete type Boolean, + // null.asInstanceOf[Boolean] unboxes to false and would never test null handling. + override def inputValueSets: Seq[Seq[_]] = Seq( + Seq( + false, + false, + false + ), + Seq( + true, + true, + true + ), + Seq( + true, + false, + null, + true, + false, + true, + null + ), + Seq( + null, + null, + null + ) + ) + + // The all-null set must yield a real null. The unchecked cast keeps it boxed because + // the base class only reads these elements through its erased type parameter T. + override def expectedResults: Seq[Boolean] = Seq[Any]( + false, + true, + true, + null + ).asInstanceOf[Seq[Boolean]] + + override def aggregator: AggregateFunction[Boolean] = new BooleanMaxAggFunction() +} + +class DecimalMaxAggFunctionTest extends AggFunctionTestBase[BigDecimal] { + + override def inputValueSets: Seq[Seq[_]] = Seq( + Seq( + new BigDecimal("1"), + new BigDecimal("1000.000001"), + new BigDecimal("-1"), + new BigDecimal("-999.998999"), + null, + new BigDecimal("0"), + new BigDecimal("-999.999"), + null, + new BigDecimal("999.999") + ), + Seq( + null, + null, + null, + null, + null + ) + ) + + override def expectedResults: Seq[BigDecimal] = Seq( + new BigDecimal("1000.000001"), + null + ) + + override def aggregator: AggregateFunction[BigDecimal] = new DecimalMaxAggFunction() +} http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionTest.scala 
---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionTest.scala new file mode 100644 index 0000000..6a6e5b9 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionTest.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.functions.aggfunctions + +import java.math.BigDecimal +import org.apache.flink.table.functions.AggregateFunction + +/** + * Test case for built-in min aggregate function + * + * @tparam T the type for the aggregation result + */ +abstract class MinAggFunctionTest[T: Numeric] extends AggFunctionTestBase[T] { + + private val numeric: Numeric[T] = implicitly[Numeric[T]] + + def minVal: T + + def maxVal: T + + override def inputValueSets: Seq[Seq[T]] = Seq( + Seq( + numeric.fromInt(1), + null.asInstanceOf[T], + maxVal, + numeric.fromInt(-99), + numeric.fromInt(3), + numeric.fromInt(56), + numeric.fromInt(0), + minVal, + numeric.fromInt(-20), + numeric.fromInt(17), + null.asInstanceOf[T] + ), + Seq( + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T] + ) + ) + + override def expectedResults: Seq[T] = Seq( + minVal, + null.asInstanceOf[T] + ) +} + +class ByteMinAggFunctionTest extends MinAggFunctionTest[Byte] { + + override def minVal = (Byte.MinValue + 1).toByte + + override def maxVal = (Byte.MaxValue - 1).toByte + + override def aggregator: AggregateFunction[Byte] = new ByteMinAggFunction() +} + +class ShortMinAggFunctionTest extends MinAggFunctionTest[Short] { + + override def minVal = (Short.MinValue + 1).toShort + + override def maxVal = (Short.MaxValue - 1).toShort + + override def aggregator: AggregateFunction[Short] = new ShortMinAggFunction() +} + +class IntMinAggFunctionTest extends MinAggFunctionTest[Int] { + + override def minVal = Int.MinValue + 1 + + override def maxVal = Int.MaxValue - 1 + + override def aggregator: AggregateFunction[Int] = new IntMinAggFunction() +} + +class LongMinAggFunctionTest extends MinAggFunctionTest[Long] { + + override def minVal = Long.MinValue + 1 + + override def maxVal = Long.MaxValue - 1 + + override def aggregator: AggregateFunction[Long] = new LongMinAggFunction() +} + +class 
FloatMinAggFunctionTest extends MinAggFunctionTest[Float] { + + override def minVal = Float.MinValue / 2 + + override def maxVal = Float.MaxValue / 2 + + override def aggregator: AggregateFunction[Float] = new FloatMinAggFunction() +} + +class DoubleMinAggFunctionTest extends MinAggFunctionTest[Double] { + + override def minVal = Double.MinValue / 2 + + override def maxVal = Double.MaxValue / 2 + + override def aggregator: AggregateFunction[Double] = new DoubleMinAggFunction() +} + +class BooleanMinAggFunctionTest extends AggFunctionTestBase[Boolean] { + + // Real nulls are required here: with the concrete type Boolean, + // null.asInstanceOf[Boolean] unboxes to false and would never test null handling. + override def inputValueSets: Seq[Seq[_]] = Seq( + Seq( + false, + false, + false + ), + Seq( + true, + true, + true + ), + Seq( + true, + false, + null, + true, + false, + true, + null + ), + Seq( + null, + null, + null + ) + ) + + // The all-null set must yield a real null. The unchecked cast keeps it boxed because + // the base class only reads these elements through its erased type parameter T. + override def expectedResults: Seq[Boolean] = Seq[Any]( + false, + true, + false, + null + ).asInstanceOf[Seq[Boolean]] + + override def aggregator: AggregateFunction[Boolean] = new BooleanMinAggFunction() +} + +class DecimalMinAggFunctionTest extends AggFunctionTestBase[BigDecimal] { + + override def inputValueSets: Seq[Seq[_]] = Seq( + Seq( + new BigDecimal("1"), + new BigDecimal("1000"), + new BigDecimal("-1"), + new BigDecimal("-999.998999"), + null, + new BigDecimal("0"), + new BigDecimal("-999.999"), + null, + new BigDecimal("999.999") + ), + Seq( + null, + null, + null, + null, + null + ) + ) + + override def expectedResults: Seq[BigDecimal] = Seq( + new BigDecimal("-999.999"), + null + ) + + override def aggregator: AggregateFunction[BigDecimal] = new DecimalMinAggFunction() +} http://git-wip-us.apache.org/repos/asf/flink/blob/dd8ef550/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunctionTest.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunctionTest.scala new file mode 100644 index 0000000..95feddd --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/functions/aggfunctions/SumAggFunctionTest.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.functions.aggfunctions + +import java.math.BigDecimal +import org.apache.flink.table.functions.AggregateFunction + +/** + * Test case for built-in sum aggregate function + * + * The first input set contains -maxVal and maxVal, which cancel out, so the + * expected sum is that of the small integers in between (2). + * + * @tparam T the type for the aggregation result + */ +abstract class SumAggFunctionTestBase[T: Numeric] extends AggFunctionTestBase[T] { + + private val numeric: Numeric[T] = implicitly[Numeric[T]] + + def maxVal: T + + // A def rather than a val: a val would call the abstract maxVal while this class + // is being constructed, before a subclass that implements maxVal as a val is initialized. + private def minVal: T = numeric.negate(maxVal) + + override def inputValueSets: Seq[Seq[T]] = Seq( + Seq( + minVal, + numeric.fromInt(1), + null.asInstanceOf[T], + numeric.fromInt(2), + numeric.fromInt(3), + numeric.fromInt(4), + numeric.fromInt(5), + numeric.fromInt(-10), + numeric.fromInt(-20), + numeric.fromInt(17), + null.asInstanceOf[T], + maxVal + ), + Seq( + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T], + null.asInstanceOf[T] + ) + ) + + override def expectedResults: Seq[T] = Seq( + numeric.fromInt(2), + null.asInstanceOf[T] + ) +} + +class ByteSumAggFunctionTest extends SumAggFunctionTestBase[Byte] { + + override def maxVal = (Byte.MaxValue / 2).toByte + + override def aggregator: AggregateFunction[Byte] = new ByteSumAggFunction +} + +class ShortSumAggFunctionTest extends SumAggFunctionTestBase[Short] { + + override def maxVal = (Short.MaxValue / 2).toShort + + override def aggregator: AggregateFunction[Short] = new ShortSumAggFunction +} + +class IntSumAggFunctionTest extends SumAggFunctionTestBase[Int] { + + override def maxVal = Int.MaxValue / 2 + + override def aggregator: AggregateFunction[Int] = new IntSumAggFunction +} + +class LongSumAggFunctionTest extends SumAggFunctionTestBase[Long] { + + override def maxVal = Long.MaxValue / 2 + + override def aggregator: AggregateFunction[Long] = new LongSumAggFunction +} + +class FloatSumAggFunctionTest extends SumAggFunctionTestBase[Float] { + + override def maxVal = 12345.6789f + + override def aggregator: 
AggregateFunction[Float] = new FloatSumAggFunction +} + +class DoubleSumAggFunctionTest extends SumAggFunctionTestBase[Double] { + + override def maxVal = 12345.6789d + + override def aggregator: AggregateFunction[Double] = new DoubleSumAggFunction +} + + +class DecimalSumAggFunctionTest extends AggFunctionTestBase[BigDecimal] { + + override def inputValueSets: Seq[Seq[_]] = Seq( + Seq( + new BigDecimal("1"), + new BigDecimal("2"), + new BigDecimal("3"), + null, + new BigDecimal("0"), + new BigDecimal("-1000"), + new BigDecimal("0.000000000002"), + new BigDecimal("1000"), + new BigDecimal("-0.000000000001"), + new BigDecimal("999.999"), + null, + new BigDecimal("4"), + new BigDecimal("-999.999"), + null + ), + Seq( + null, + null, + null, + null, + null + ) + ) + + override def expectedResults: Seq[BigDecimal] = Seq( + new BigDecimal("10.000000000001"), + null + ) + + override def aggregator: AggregateFunction[BigDecimal] = new DecimalSumAggFunction() +} + +
