asfgit closed pull request #7201: [FLINK-7208] [table] Optimize Min/MaxWithRetractAggFunction with DataView URL: https://github.com/apache/flink/pull/7201
This is a PR merged from a forked repository. Because GitHub hides the original diff of a pull request from a fork once it has been merged, the diff is reproduced below for the sake of provenance: diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala index cbb395e232f..fde93ff1679 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MaxAggFunctionWithRetract.scala @@ -18,19 +18,20 @@ package org.apache.flink.table.functions.aggfunctions import java.math.BigDecimal -import java.util.{HashMap => JHashMap} -import java.lang.{Iterable => JIterable} +import java.lang.{Iterable => JIterable, Long => JLong} import java.sql.{Date, Time, Timestamp} -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} -import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} -import org.apache.flink.api.java.typeutils.{MapTypeInfo, TupleTypeInfo} -import org.apache.flink.table.api.Types +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation, Types} +import org.apache.flink.table.api.dataview.MapView import org.apache.flink.table.functions.aggfunctions.Ordering._ import org.apache.flink.table.functions.AggregateFunction /** The initial accumulator for Max with retraction aggregate function */ -class MaxWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]] +class MaxWithRetractAccumulator[T] { + var max: T = _ + var distinctCount: JLong = _ + var map: MapView[T, JLong] = _ +} /** * Base class for built-in Max with retraction aggregate function @@ -42,8 +43,10 @@ abstract class 
MaxWithRetractAggFunction[T](implicit ord: Ordering[T]) override def createAccumulator(): MaxWithRetractAccumulator[T] = { val acc = new MaxWithRetractAccumulator[T] - acc.f0 = getInitValue //max - acc.f1 = new JHashMap[T, Long]() //store the count for each value + acc.max = getInitValue //max + acc.distinctCount = 0L + acc.map = new MapView(getValueTypeInfo, Types.LONG) + .asInstanceOf[MapView[T, JLong]] //store the count for each value acc } @@ -51,16 +54,17 @@ abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T]) if (value != null) { val v = value.asInstanceOf[T] - if (acc.f1.size() == 0 || (ord.compare(acc.f0, v) < 0)) { - acc.f0 = v + if (acc.distinctCount == 0 || (ord.compare(acc.max, v) < 0)) { + acc.max = v } - if (!acc.f1.containsKey(v)) { - acc.f1.put(v, 1L) + var count = acc.map.get(v) + if (count == null) { + acc.map.put(v, 1L) + acc.distinctCount += 1 } else { - var count = acc.f1.get(v) count += 1L - acc.f1.put(v, count) + acc.map.put(v, count) } } } @@ -69,39 +73,45 @@ abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T]) if (value != null) { val v = value.asInstanceOf[T] - var count = acc.f1.get(v) - count -= 1L - if (count == 0) { + val count = acc.map.get(v) + if (count == null || count == 1) { //remove the key v from the map if the number of appearance of the value v is 0 - acc.f1.remove(v) + if (count != null) { + acc.map.remove(v) + } //if the total count is 0, we could just simply set the f0(max) to the initial value - if (acc.f1.size() == 0) { - acc.f0 = getInitValue + acc.distinctCount -= 1 + if (acc.distinctCount == 0) { + acc.max = getInitValue return } //if v is the current max value, we have to iterate the map to find the 2nd biggest // value to replace v as the max value - if (v == acc.f0) { - val iterator = acc.f1.keySet().iterator() - var key = iterator.next() - acc.f0 = key + if (v == acc.max) { + val iterator = acc.map.keys.iterator() + var hasMax = false while (iterator.hasNext) { - key = 
iterator.next() - if (ord.compare(acc.f0, key) < 0) { - acc.f0 = key + val key = iterator.next() + if (!hasMax || ord.compare(acc.max, key) < 0) { + acc.max = key + hasMax = true } } + + if (!hasMax) { + acc.distinctCount = 0L + } } } else { - acc.f1.put(v, count) + acc.map.put(v, count - 1) } } } override def getValue(acc: MaxWithRetractAccumulator[T]): T = { - if (acc.f1.size() != 0) { - acc.f0 + if (acc.distinctCount != 0) { + acc.max } else { null.asInstanceOf[T] } @@ -112,19 +122,23 @@ abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T]) val iter = its.iterator() while (iter.hasNext) { val a = iter.next() - if (a.f1.size() != 0) { + if (a.distinctCount != 0) { // set max element - if (ord.compare(acc.f0, a.f0) < 0) { - acc.f0 = a.f0 + if (ord.compare(acc.max, a.max) < 0) { + acc.max = a.max } // merge the count for each key - val iterator = a.f1.keySet().iterator() + val iterator = a.map.entries.iterator() while (iterator.hasNext) { - val key = iterator.next() - if (acc.f1.containsKey(key)) { - acc.f1.put(key, acc.f1.get(key) + a.f1.get(key)) + val entry = iterator.next() + val key = entry.getKey + val value = entry.getValue + val count = acc.map.get(key) + if (count != null) { + acc.map.put(key, count + value) } else { - acc.f1.put(key, a.f1.get(key)) + acc.map.put(key, value) + acc.distinctCount += 1 } } } @@ -132,15 +146,9 @@ abstract class MaxWithRetractAggFunction[T](implicit ord: Ordering[T]) } def resetAccumulator(acc: MaxWithRetractAccumulator[T]): Unit = { - acc.f0 = getInitValue - acc.f1.clear() - } - - override def getAccumulatorType: TypeInformation[MaxWithRetractAccumulator[T]] = { - new TupleTypeInfo( - classOf[MaxWithRetractAccumulator[T]], - getValueTypeInfo, - new MapTypeInfo(getValueTypeInfo, BasicTypeInfo.LONG_TYPE_INFO)) + acc.max = getInitValue + acc.distinctCount = 0L + acc.map.clear() } def getInitValue: T diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala index 480f836e849..f62f2eceb29 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/MinAggFunctionWithRetract.scala @@ -18,19 +18,20 @@ package org.apache.flink.table.functions.aggfunctions import java.math.BigDecimal -import java.util.{HashMap => JHashMap} -import java.lang.{Iterable => JIterable} +import java.lang.{Iterable => JIterable, Long => JLong} import java.sql.{Date, Time, Timestamp} -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} -import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} -import org.apache.flink.api.java.typeutils.{MapTypeInfo, TupleTypeInfo} -import org.apache.flink.table.api.Types +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation, Types} +import org.apache.flink.table.api.dataview.MapView import org.apache.flink.table.functions.aggfunctions.Ordering._ import org.apache.flink.table.functions.AggregateFunction /** The initial accumulator for Min with retraction aggregate function */ -class MinWithRetractAccumulator[T] extends JTuple2[T, JHashMap[T, Long]] +class MinWithRetractAccumulator[T] { + var min: T = _ + var distinctCount: JLong = _ + var map: MapView[T, JLong] = _ +} /** * Base class for built-in Min with retraction aggregate function @@ -42,8 +43,10 @@ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T]) override def createAccumulator(): MinWithRetractAccumulator[T] = { val acc = new MinWithRetractAccumulator[T] - acc.f0 = getInitValue //min - acc.f1 = new JHashMap[T, Long]() //store the count for each value + acc.min = getInitValue //min + 
acc.distinctCount = 0L + acc.map = new MapView(getValueTypeInfo, Types.LONG) + .asInstanceOf[MapView[T, JLong]] //store the count for each value acc } @@ -51,16 +54,17 @@ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T]) if (value != null) { val v = value.asInstanceOf[T] - if (acc.f1.size() == 0 || (ord.compare(acc.f0, v) > 0)) { - acc.f0 = v + if (acc.distinctCount == 0 || (ord.compare(acc.min, v) > 0)) { + acc.min = v } - if (!acc.f1.containsKey(v)) { - acc.f1.put(v, 1L) + var count = acc.map.get(v) + if (count == null) { + acc.map.put(v, 1L) + acc.distinctCount += 1 } else { - var count = acc.f1.get(v) count += 1L - acc.f1.put(v, count) + acc.map.put(v, count) } } } @@ -69,39 +73,45 @@ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T]) if (value != null) { val v = value.asInstanceOf[T] - var count = acc.f1.get(v) - count -= 1L - if (count == 0) { + val count = acc.map.get(v) + if (count == null || count == 1) { //remove the key v from the map if the number of appearance of the value v is 0 - acc.f1.remove(v) + if (count != null) { + acc.map.remove(v) + } //if the total count is 0, we could just simply set the f0(min) to the initial value - if (acc.f1.size() == 0) { - acc.f0 = getInitValue + acc.distinctCount -= 1 + if (acc.distinctCount == 0) { + acc.min = getInitValue return } //if v is the current min value, we have to iterate the map to find the 2nd smallest // value to replace v as the min value - if (v == acc.f0) { - val iterator = acc.f1.keySet().iterator() - var key = iterator.next() - acc.f0 = key + if (v == acc.min) { + val iterator = acc.map.keys.iterator() + var hasMin = false while (iterator.hasNext) { - key = iterator.next() - if (ord.compare(acc.f0, key) > 0) { - acc.f0 = key + val key = iterator.next() + if (!hasMin || ord.compare(acc.min, key) > 0) { + acc.min = key + hasMin = true } } + + if (!hasMin) { + acc.distinctCount = 0L + } } } else { - acc.f1.put(v, count) + acc.map.put(v, count - 1) } } } 
override def getValue(acc: MinWithRetractAccumulator[T]): T = { - if (acc.f1.size() != 0) { - acc.f0 + if (acc.distinctCount != 0) { + acc.min } else { null.asInstanceOf[T] } @@ -112,19 +122,23 @@ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T]) val iter = its.iterator() while (iter.hasNext) { val a = iter.next() - if (a.f1.size() != 0) { + if (a.distinctCount != 0) { // set min element - if (ord.compare(acc.f0, a.f0) > 0) { - acc.f0 = a.f0 + if (ord.compare(acc.min, a.min) > 0) { + acc.min = a.min } // merge the count for each key - val iterator = a.f1.keySet().iterator() + val iterator = a.map.entries.iterator() while (iterator.hasNext) { - val key = iterator.next() - if (acc.f1.containsKey(key)) { - acc.f1.put(key, acc.f1.get(key) + a.f1.get(key)) + val entry = iterator.next() + val key = entry.getKey + val value = entry.getValue + val count = acc.map.get(key) + if (count != null) { + acc.map.put(key, count + value) } else { - acc.f1.put(key, a.f1.get(key)) + acc.map.put(key, value) + acc.distinctCount += 1 } } } @@ -132,15 +146,9 @@ abstract class MinWithRetractAggFunction[T](implicit ord: Ordering[T]) } def resetAccumulator(acc: MinWithRetractAccumulator[T]): Unit = { - acc.f0 = getInitValue - acc.f1.clear() - } - - override def getAccumulatorType: TypeInformation[MinWithRetractAccumulator[T]] = { - new TupleTypeInfo( - classOf[MinWithRetractAccumulator[T]], - getValueTypeInfo, - new MapTypeInfo(getValueTypeInfo, BasicTypeInfo.LONG_TYPE_INFO)) + acc.min = getInitValue + acc.distinctCount = 0L + acc.map.clear() } def getInitValue: T diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index f1386dfea46..f5cf1910b9a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -1250,23 +1250,17 @@ object AggregateUtil { aggregateInputTypes, tableConfig) - val (accumulatorType, accSpecs) = aggregateFunction match { - case collect: SqlAggFunction if collect.getKind == SqlKind.COLLECT => - removeStateViewFieldsFromAccTypeInfo( - uniqueIdWithinAggregate, - aggregate, - aggregate.getAccumulatorType, - isStateBackedDataViews) - - case udagg: AggSqlFunction => - removeStateViewFieldsFromAccTypeInfo( - uniqueIdWithinAggregate, - aggregate, - udagg.accType, - isStateBackedDataViews) + val (accumulatorType, accSpecs) = { + val accType = aggregateFunction match { + case udagg: AggSqlFunction => udagg.accType + case _ => getAccumulatorTypeOfAggregateFunction(aggregate) + } - case _ => - (getAccumulatorTypeOfAggregateFunction(aggregate), None) + removeStateViewFieldsFromAccTypeInfo( + uniqueIdWithinAggregate, + aggregate, + accType, + isStateBackedDataViews) } // create distinct accumulator filter argument diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala index bdd1df04894..f0e273215c5 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AggFunctionTestBase.scala @@ -23,7 +23,7 @@ import java.math.BigDecimal import java.util.{ArrayList => JArrayList, List => JList} import org.apache.flink.table.functions.AggregateFunction -import org.apache.flink.table.functions.aggfunctions.{DecimalAvgAccumulator, DecimalSumWithRetractAccumulator} +import org.apache.flink.table.functions.aggfunctions.{DecimalAvgAccumulator, DecimalSumWithRetractAccumulator, MaxWithRetractAccumulator, 
MinWithRetractAccumulator} import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils._ import org.junit.Assert.assertEquals import org.junit.Test @@ -137,6 +137,12 @@ abstract class AggFunctionTestBase[T, ACC] { case (e: BigDecimal, r: BigDecimal) => // BigDecimal.equals() value and scale but we are only interested in value. assert(e.compareTo(r) == 0) + case (e: MinWithRetractAccumulator[_], r: MinWithRetractAccumulator[_]) => + assertEquals(e.min, r.min) + assertEquals(e.distinctCount, r.distinctCount) + case (e: MaxWithRetractAccumulator[_], r: MaxWithRetractAccumulator[_]) => + assertEquals(e.max, r.max) + assertEquals(e.distinctCount, r.distinctCount) case _ => assertEquals(expected, result) } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/AggFunctionHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/AggFunctionHarnessTest.scala index 0549339381d..5c1a780e217 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/AggFunctionHarnessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/AggFunctionHarnessTest.scala @@ -107,4 +107,84 @@ class AggFunctionHarnessTest extends HarnessTestBase { testHarness.close() } + + @Test + def testMinMaxAggFunctionWithRetract(): Unit = { + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + + val data = new mutable.MutableList[(JInt, JInt, String)] + val t = env.fromCollection(data).toTable(tEnv, 'a, 'b, 'c) + tEnv.registerTable("T", t) + val sqlQuery = tEnv.sqlQuery( + s""" + |SELECT + | c, min(a), max(b) + |FROM ( + | SELECT a, b, c + | FROM T + | GROUP BY a, b, c + |) GROUP BY c + |""".stripMargin) + + val testHarness = createHarnessTester[String, CRow, CRow]( + sqlQuery.toRetractStream[Row](queryConfig), "groupBy") + + testHarness.setStateBackend(getStateBackend) + 
testHarness.open() + + val operator = getOperator(testHarness) + val minState = getState( + operator, + "function", + classOf[GroupAggProcessFunction], + "acc0_map_dataview").asInstanceOf[MapView[JInt, JInt]] + val maxState = getState( + operator, + "function", + classOf[GroupAggProcessFunction], + "acc1_map_dataview").asInstanceOf[MapView[JInt, JInt]] + assertTrue(minState.isInstanceOf[StateMapView[_, _]]) + assertTrue(maxState.isInstanceOf[StateMapView[_, _]]) + assertTrue(operator.getKeyedStateBackend.isInstanceOf[RocksDBKeyedStateBackend[_]]) + + val expectedOutput = new ConcurrentLinkedQueue[Object]() + + testHarness.processElement(new StreamRecord(CRow(1: JInt, 1: JInt, "aaa"), 1)) + expectedOutput.add(new StreamRecord(CRow("aaa", 1, 1), 1)) + + testHarness.processElement(new StreamRecord(CRow(1: JInt, 1: JInt, "bbb"), 1)) + expectedOutput.add(new StreamRecord(CRow("bbb", 1, 1), 1)) + + // min/max doesn't change + testHarness.processElement(new StreamRecord(CRow(2: JInt, 0: JInt, "aaa"), 1)) + + // min/max changed + testHarness.processElement(new StreamRecord(CRow(0: JInt, 2: JInt, "aaa"), 1)) + expectedOutput.add(new StreamRecord(CRow(false, "aaa", 1, 1), 1)) + expectedOutput.add(new StreamRecord(CRow("aaa", 0, 2), 1)) + + // retract the min/max value + testHarness.processElement(new StreamRecord(CRow(false, 0: JInt, 2: JInt, "aaa"), 1)) + expectedOutput.add(new StreamRecord(CRow(false, "aaa", 0, 2), 1)) + expectedOutput.add(new StreamRecord(CRow("aaa", 1, 1), 1)) + + // remove some state: state may be cleaned up by the state backend + // if not accessed beyond ttl time + operator.setCurrentKey(Row.of("aaa")) + minState.remove(1) + maxState.remove(1) + + // retract after state has been cleaned up + testHarness.processElement(new StreamRecord(CRow(false, 2: JInt, 0: JInt, "aaa"), 1)) + + testHarness.processElement(new StreamRecord(CRow(false, 1: JInt, 1: JInt, "aaa"), 1)) + expectedOutput.add(new StreamRecord(CRow(false, "aaa", 1, 1), 1)) + + val result = 
testHarness.getOutput + + verify(expectedOutput, result) + + testHarness.close() + } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala index e4d938d4786..2e9dac5e7e5 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/table/AggregateITCase.scala @@ -203,13 +203,13 @@ class AggregateITCase extends StreamingWithStateTestBase { .groupBy('b) .select('a.count as 'cnt, 'b) .groupBy('cnt) - .select('cnt, 'b.count as 'freq) + .select('cnt, 'b.count as 'freq, 'b.min as 'min, 'b.max as 'max) val results = t.toRetractStream[Row](queryConfig) results.addSink(new RetractingSink) env.execute() - val expected = List("1,1", "2,1", "3,1", "4,1", "5,1", "6,1") + val expected = List("1,1,1,1", "2,1,2,2", "3,1,3,3", "4,1,4,4", "5,1,5,5", "6,1,6,6") assertEquals(expected.sorted, StreamITCase.retractedResults.sorted) } ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services