[ https://issues.apache.org/jira/browse/FLINK-10674?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16695922#comment-16695922 ]
ASF GitHub Bot commented on FLINK-10674: ---------------------------------------- asfgit closed pull request #7147: [FLINK-10674] [table] Fix handling of retractions after clean up URL: https://github.com/apache/flink/pull/7147 This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/flink-libraries/flink-table-common/src/main/java/org/apache/flink/table/utils/EncodingUtils.java b/flink-libraries/flink-table-common/src/main/java/org/apache/flink/table/utils/EncodingUtils.java index 47aac25e897..5531082611d 100644 --- a/flink-libraries/flink-table-common/src/main/java/org/apache/flink/table/utils/EncodingUtils.java +++ b/flink-libraries/flink-table-common/src/main/java/org/apache/flink/table/utils/EncodingUtils.java @@ -76,7 +76,7 @@ public static String encodeObjectToString(Serializable obj) { return instance; } catch (Exception e) { throw new ValidationException( - "Unable to deserialize string '" + base64String + "' of base class '" + baseClass.getName() + "'."); + "Unable to deserialize string '" + base64String + "' of base class '" + baseClass.getName() + "'.", e); } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala index 397032003ec..f591c4f2299 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/GroupAggProcessFunction.scala @@ -95,6 +95,12 @@ class GroupAggProcessFunction( var inputCnt = cntState.value() if (null == accumulators) { + // don't create a new accumulator for unknown 
retractions + // e.g. retractions that come in right after state clean up + if (!inputC.change) { + return + } + // first accumulate message firstRow = true accumulators = function.createAccumulators() } else { diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala similarity index 65% rename from flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala rename to flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala index 7c4f5430328..1dce9946877 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/NonWindowHarnessTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/GroupAggregateHarnessTest.scala @@ -24,20 +24,18 @@ import org.apache.flink.api.common.time.Time import org.apache.flink.api.common.typeinfo.BasicTypeInfo import org.apache.flink.streaming.api.operators.LegacyKeyedProcessOperator import org.apache.flink.streaming.runtime.streamrecord.StreamRecord -import org.apache.flink.table.api.StreamQueryConfig import org.apache.flink.table.runtime.aggregate._ import org.apache.flink.table.runtime.harness.HarnessTestBase._ import org.apache.flink.table.runtime.types.CRow -import org.apache.flink.types.Row import org.junit.Test -class NonWindowHarnessTest extends HarnessTestBase { +class GroupAggregateHarnessTest extends HarnessTestBase { protected var queryConfig = new TestStreamQueryConfig(Time.seconds(2), Time.seconds(3)) @Test - def testNonWindow(): Unit = { + def testAggregate(): Unit = { val processFunction = new LegacyKeyedProcessOperator[String, CRow, CRow]( new GroupAggProcessFunction( @@ -54,50 +52,49 @@ class NonWindowHarnessTest extends HarnessTestBase { testHarness.open() + val 
expectedOutput = new ConcurrentLinkedQueue[Object]() + // register cleanup timer with 3001 testHarness.setProcessingTime(1) testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa"), 1)) + expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1: JInt), 1)) testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb"), 1)) + expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1: JInt), 1)) // reuse timer 3001 testHarness.setProcessingTime(1000) testHarness.processElement(new StreamRecord(CRow(3L: JLong, 2: JInt, "aaa"), 1)) + expectedOutput.add(new StreamRecord(CRow(3L: JLong, 3: JInt), 1)) testHarness.processElement(new StreamRecord(CRow(4L: JLong, 3: JInt, "aaa"), 1)) + expectedOutput.add(new StreamRecord(CRow(4L: JLong, 6: JInt), 1)) // register cleanup timer with 4002 testHarness.setProcessingTime(1002) testHarness.processElement(new StreamRecord(CRow(5L: JLong, 4: JInt, "aaa"), 1)) + expectedOutput.add(new StreamRecord(CRow(5L: JLong, 10: JInt), 1)) testHarness.processElement(new StreamRecord(CRow(6L: JLong, 2: JInt, "bbb"), 1)) + expectedOutput.add(new StreamRecord(CRow(6L: JLong, 3: JInt), 1)) // trigger cleanup timer and register cleanup timer with 7003 testHarness.setProcessingTime(4003) testHarness.processElement(new StreamRecord(CRow(7L: JLong, 5: JInt, "aaa"), 1)) + expectedOutput.add(new StreamRecord(CRow(7L: JLong, 5: JInt), 1)) testHarness.processElement(new StreamRecord(CRow(8L: JLong, 6: JInt, "aaa"), 1)) + expectedOutput.add(new StreamRecord(CRow(8L: JLong, 11: JInt), 1)) testHarness.processElement(new StreamRecord(CRow(9L: JLong, 7: JInt, "aaa"), 1)) + expectedOutput.add(new StreamRecord(CRow(9L: JLong, 18: JInt), 1)) testHarness.processElement(new StreamRecord(CRow(10L: JLong, 3: JInt, "bbb"), 1)) + expectedOutput.add(new StreamRecord(CRow(10L: JLong, 3: JInt), 1)) val result = testHarness.getOutput - val expectedOutput = new ConcurrentLinkedQueue[Object]() - - expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1: 
JInt), 1)) - expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1: JInt), 1)) - expectedOutput.add(new StreamRecord(CRow(3L: JLong, 3: JInt), 1)) - expectedOutput.add(new StreamRecord(CRow(4L: JLong, 6: JInt), 1)) - expectedOutput.add(new StreamRecord(CRow(5L: JLong, 10: JInt), 1)) - expectedOutput.add(new StreamRecord(CRow(6L: JLong, 3: JInt), 1)) - expectedOutput.add(new StreamRecord(CRow(7L: JLong, 5: JInt), 1)) - expectedOutput.add(new StreamRecord(CRow(8L: JLong, 11: JInt), 1)) - expectedOutput.add(new StreamRecord(CRow(9L: JLong, 18: JInt), 1)) - expectedOutput.add(new StreamRecord(CRow(10L: JLong, 3: JInt), 1)) - verify(expectedOutput, result) testHarness.close() } @Test - def testNonWindowWithRetract(): Unit = { + def testAggregateWithRetract(): Unit = { val processFunction = new LegacyKeyedProcessOperator[String, CRow, CRow]( new GroupAggProcessFunction( @@ -114,42 +111,136 @@ class NonWindowHarnessTest extends HarnessTestBase { testHarness.open() + val expectedOutput = new ConcurrentLinkedQueue[Object]() + // register cleanup timer with 3001 testHarness.setProcessingTime(1) + // accumulate testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa"), 1)) + expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1: JInt), 1)) + + // accumulate testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb"), 2)) + expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1: JInt), 2)) + + // retract for insertion testHarness.processElement(new StreamRecord(CRow(3L: JLong, 2: JInt, "aaa"), 3)) + expectedOutput.add(new StreamRecord(CRow(false, 3L: JLong, 1: JInt), 3)) + expectedOutput.add(new StreamRecord(CRow(3L: JLong, 3: JInt), 3)) + + // retract for deletion + testHarness.processElement(new StreamRecord(CRow(false, 3L: JLong, 2: JInt, "aaa"), 3)) + expectedOutput.add(new StreamRecord(CRow(false, 3L: JLong, 3: JInt), 3)) + expectedOutput.add(new StreamRecord(CRow(3L: JLong, 1: JInt), 3)) + + // accumulate testHarness.processElement(new 
StreamRecord(CRow(4L: JLong, 3: JInt, "ccc"), 4)) + expectedOutput.add(new StreamRecord(CRow(4L: JLong, 3: JInt), 4)) // trigger cleanup timer and register cleanup timer with 6002 testHarness.setProcessingTime(3002) - testHarness.processElement(new StreamRecord(CRow(5L: JLong, 4: JInt, "aaa"), 5)) - testHarness.processElement(new StreamRecord(CRow(6L: JLong, 2: JInt, "bbb"), 6)) - testHarness.processElement(new StreamRecord(CRow(7L: JLong, 5: JInt, "aaa"), 7)) - testHarness.processElement(new StreamRecord(CRow(8L: JLong, 6: JInt, "eee"), 8)) - testHarness.processElement(new StreamRecord(CRow(9L: JLong, 7: JInt, "aaa"), 9)) - testHarness.processElement(new StreamRecord(CRow(10L: JLong, 3: JInt, "bbb"), 10)) - - val result = testHarness.getOutput - val expectedOutput = new ConcurrentLinkedQueue[Object]() + // retract after clean up + testHarness.processElement(new StreamRecord(CRow(false, 4L: JLong, 3: JInt, "ccc"), 4)) - expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1: JInt), 1)) - expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1: JInt), 2)) - expectedOutput.add(new StreamRecord(CRow(false, 3L: JLong, 1: JInt), 3)) - expectedOutput.add(new StreamRecord(CRow(3L: JLong, 3: JInt), 3)) - expectedOutput.add(new StreamRecord(CRow(4L: JLong, 3: JInt), 4)) + // accumulate + testHarness.processElement(new StreamRecord(CRow(5L: JLong, 4: JInt, "aaa"), 5)) expectedOutput.add(new StreamRecord(CRow(5L: JLong, 4: JInt), 5)) + testHarness.processElement(new StreamRecord(CRow(6L: JLong, 2: JInt, "bbb"), 6)) expectedOutput.add(new StreamRecord(CRow(6L: JLong, 2: JInt), 6)) + + // retract + testHarness.processElement(new StreamRecord(CRow(7L: JLong, 5: JInt, "aaa"), 7)) expectedOutput.add(new StreamRecord(CRow(false, 7L: JLong, 4: JInt), 7)) expectedOutput.add(new StreamRecord(CRow(7L: JLong, 9: JInt), 7)) + + // accumulate + testHarness.processElement(new StreamRecord(CRow(8L: JLong, 6: JInt, "eee"), 8)) expectedOutput.add(new StreamRecord(CRow(8L: JLong, 6: JInt), 8)) + 
+ // retract + testHarness.processElement(new StreamRecord(CRow(9L: JLong, 7: JInt, "aaa"), 9)) expectedOutput.add(new StreamRecord(CRow(false, 9L: JLong, 9: JInt), 9)) expectedOutput.add(new StreamRecord(CRow(9L: JLong, 16: JInt), 9)) + testHarness.processElement(new StreamRecord(CRow(10L: JLong, 3: JInt, "bbb"), 10)) expectedOutput.add(new StreamRecord(CRow(false, 10L: JLong, 2: JInt), 10)) expectedOutput.add(new StreamRecord(CRow(10L: JLong, 5: JInt), 10)) + val result = testHarness.getOutput + + verify(expectedOutput, result) + + testHarness.close() + } + + @Test + def testDistinctAggregateWithRetract(): Unit = { + + val processFunction = new LegacyKeyedProcessOperator[String, CRow, CRow]( + new GroupAggProcessFunction( + genDistinctCountAggFunction, + distinctCountAggregationStateType, + true, + queryConfig)) + + val testHarness = + createHarnessTester( + processFunction, + new TupleRowKeySelector[String](2), + BasicTypeInfo.STRING_TYPE_INFO) + + testHarness.open() + + val expectedOutput = new ConcurrentLinkedQueue[Object]() + + // register cleanup timer with 3001 + testHarness.setProcessingTime(1) + + // insert + testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa"))) + expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong))) + testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb"))) + expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong))) + + // distinct count retract then accumulate for downstream operators + testHarness.processElement(new StreamRecord(CRow(2L: JLong, 1: JInt, "bbb"))) + expectedOutput.add(new StreamRecord(CRow(false, 2L: JLong, 1L: JLong))) + expectedOutput.add(new StreamRecord(CRow(2L: JLong, 1L: JLong))) + + // update count for accumulate + testHarness.processElement(new StreamRecord(CRow(1L: JLong, 2: JInt, "aaa"))) + expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 1L: JLong))) + expectedOutput.add(new StreamRecord(CRow(1L: JLong, 2L: JLong))) + + // update count 
for retraction + testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 2: JInt, "aaa"))) + expectedOutput.add(new StreamRecord(CRow(false, 1L: JLong, 2L: JLong))) + expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong))) + + // insert + testHarness.processElement(new StreamRecord(CRow(4L: JLong, 3: JInt, "ccc"))) + expectedOutput.add(new StreamRecord(CRow(4L: JLong, 1L: JLong))) + + // retract entirely + testHarness.processElement(new StreamRecord(CRow(false, 4L: JLong, 3: JInt, "ccc"))) + expectedOutput.add(new StreamRecord(CRow(false, 4L: JLong, 1L: JLong))) + + // trigger cleanup timer and register cleanup timer with 6002 + testHarness.setProcessingTime(3002) + + // insert + testHarness.processElement(new StreamRecord(CRow(1L: JLong, 1: JInt, "aaa"))) + expectedOutput.add(new StreamRecord(CRow(1L: JLong, 1L: JLong))) + + // trigger cleanup timer and register cleanup timer with 9002 + testHarness.setProcessingTime(6002) + + // retract after cleanup + testHarness.processElement(new StreamRecord(CRow(false, 1L: JLong, 1: JInt, "aaa"))) + + val result = testHarness.getOutput + verify(expectedOutput, result) testHarness.close() diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala index f70d991e50b..e5cceecc560 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/harness/HarnessTestBase.scala @@ -19,8 +19,9 @@ package org.apache.flink.table.runtime.harness import java.util.{Comparator, Queue => JQueue} +import org.apache.flink.api.common.state.{MapStateDescriptor, StateDescriptor} import org.apache.flink.api.common.time.Time -import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{INT_TYPE_INFO, LONG_TYPE_INFO, STRING_TYPE_INFO} 
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{LONG_TYPE_INFO, STRING_TYPE_INFO} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.functions.KeySelector import org.apache.flink.api.java.typeutils.RowTypeInfo @@ -28,11 +29,11 @@ import org.apache.flink.streaming.api.operators.OneInputStreamOperator import org.apache.flink.streaming.api.watermark.Watermark import org.apache.flink.streaming.runtime.streamrecord.StreamRecord import org.apache.flink.streaming.util.{KeyedOneInputStreamOperatorTestHarness, TestHarnessUtil} -import org.apache.flink.table.api.StreamQueryConfig +import org.apache.flink.table.api.{StreamQueryConfig, Types} import org.apache.flink.table.codegen.GeneratedAggregationsFunction -import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction} -import org.apache.flink.table.functions.aggfunctions.{IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction} +import org.apache.flink.table.functions.aggfunctions.{CountAggFunction, IntSumWithRetractAggFunction, LongMaxWithRetractAggFunction, LongMinWithRetractAggFunction} import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction +import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction} import org.apache.flink.table.runtime.harness.HarnessTestBase.{RowResultSortComparator, RowResultSortComparatorWithWatermarks} import org.apache.flink.table.runtime.types.{CRow, CRowTypeInfo} import org.apache.flink.table.utils.EncodingUtils @@ -55,20 +56,16 @@ class HarnessTestBase { val intSumWithRetractAggFunction: String = EncodingUtils.encodeObjectToString(new IntSumWithRetractAggFunction) + val distinctCountAggFunction: String = + EncodingUtils.encodeObjectToString(new CountAggFunction()) + protected val MinMaxRowType = new RowTypeInfo(Array[TypeInformation[_]]( LONG_TYPE_INFO, STRING_TYPE_INFO, LONG_TYPE_INFO), Array("rowtime", "a", "b")) 
- protected val SumRowType = new RowTypeInfo(Array[TypeInformation[_]]( - LONG_TYPE_INFO, - INT_TYPE_INFO, - STRING_TYPE_INFO), - Array("a", "b", "c")) - protected val minMaxCRowType = new CRowTypeInfo(MinMaxRowType) - protected val sumCRowType = new CRowTypeInfo(SumRowType) protected val minMaxAggregates: Array[AggregateFunction[_, _]] = Array(new LongMinWithRetractAggFunction, @@ -77,15 +74,28 @@ class HarnessTestBase { protected val sumAggregates: Array[AggregateFunction[_, _]] = Array(new IntSumWithRetractAggFunction).asInstanceOf[Array[AggregateFunction[_, _]]] + protected val distinctCountAggregates: Array[AggregateFunction[_, _]] = + Array(new CountAggFunction).asInstanceOf[Array[AggregateFunction[_, _]]] + protected val minMaxAggregationStateType: RowTypeInfo = new RowTypeInfo(minMaxAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*) protected val sumAggregationStateType: RowTypeInfo = new RowTypeInfo(sumAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*) + protected val distinctCountAggregationStateType: RowTypeInfo = + new RowTypeInfo(distinctCountAggregates.map(getAccumulatorTypeOfAggregateFunction(_)): _*) + + protected val distinctCountDescriptor: String = EncodingUtils.encodeObjectToString( + new MapStateDescriptor("distinctAgg0", distinctCountAggregationStateType, Types.LONG)) + + protected val minMaxFuncName = "MinMaxAggregateHelper" + protected val sumFuncName = "SumAggregationHelper" + protected val distinctCountFuncName = "DistinctCountAggregationHelper" + val minMaxCode: String = s""" - |public class MinMaxAggregateHelper + |public class $minMaxFuncName | extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations { | | transient org.apache.flink.table.functions.aggfunctions.LongMinWithRetractAggFunction @@ -94,7 +104,7 @@ class HarnessTestBase { | transient org.apache.flink.table.functions.aggfunctions.LongMaxWithRetractAggFunction | fmax = null; | - | public MinMaxAggregateHelper() throws Exception { + | public 
$minMaxFuncName() throws Exception { | | fmin = (org.apache.flink.table.functions.aggfunctions.LongMinWithRetractAggFunction) | ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject( @@ -207,25 +217,25 @@ class HarnessTestBase { val sumAggCode: String = s""" - |public final class SumAggregationHelper + |public final class $sumFuncName | extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations { | | - |transient org.apache.flink.table.functions.aggfunctions.IntSumWithRetractAggFunction - |sum = null; - |private final org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache + | transient org.apache.flink.table.functions.aggfunctions.IntSumWithRetractAggFunction + | sum = null; + | private final org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache | .flink.table.functions.aggfunctions.SumWithRetractAccumulator> accIt0 = | new org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache.flink | .table | .functions.aggfunctions.SumWithRetractAccumulator>(); | - | public SumAggregationHelper() throws Exception { + | public $sumFuncName() throws Exception { | - |sum = (org.apache.flink.table.functions.aggfunctions.IntSumWithRetractAggFunction) - |${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject( - | "$intSumWithRetractAggFunction", - | ${classOf[UserDefinedFunction].getCanonicalName}.class); - |} + | sum = (org.apache.flink.table.functions.aggfunctions.IntSumWithRetractAggFunction) + | ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject( + | "$intSumWithRetractAggFunction", + | ${classOf[UserDefinedFunction].getCanonicalName}.class); + | } | | public final void setAggregationResults( | org.apache.flink.types.Row accs, @@ -256,6 +266,12 @@ class HarnessTestBase { | public final void retract( | org.apache.flink.types.Row accs, | org.apache.flink.types.Row input) { + | + | sum.retract( + | ((org.apache.flink.table.functions.aggfunctions.SumWithRetractAccumulator) accs + | 
.getField + | (0)), + | (java.lang.Integer) input.getField(1)); | } | | public final org.apache.flink.types.Row createAccumulators() @@ -281,6 +297,162 @@ class HarnessTestBase { | input.getField(0)); | } | + | public final org.apache.flink.types.Row createOutputRow() { + | return new org.apache.flink.types.Row(2); + | } + | + | + | public final org.apache.flink.types.Row mergeAccumulatorsPair( + | org.apache.flink.types.Row a, + | org.apache.flink.types.Row b) + | { + | + | return a; + | + | } + | + | public final void resetAccumulator( + | org.apache.flink.types.Row accs) { + | } + | + | public void open(org.apache.flink.api.common.functions.RuntimeContext ctx) { + | } + | + | public void cleanup() { + | } + | + | public void close() { + | } + |} + |""".stripMargin + + val distinctCountAggCode: String = + s""" + |public final class $distinctCountFuncName + | extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations { + | + | final org.apache.flink.table.functions.aggfunctions.CountAggFunction count; + | + | final org.apache.flink.table.api.dataview.MapView acc0_distinctValueMap_dataview; + | + | final java.lang.reflect.Field distinctValueMap = + | org.apache.flink.api.java.typeutils.TypeExtractor.getDeclaredField( + | org.apache.flink.table.functions.aggfunctions.DistinctAccumulator.class, + | "distinctValueMap"); + | + | + | private final org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache + | .flink.table.functions.aggfunctions.CountAccumulator> accIt0 = + | new org.apache.flink.table.runtime.aggregate.SingleElementIterable<org.apache.flink + | .table + | .functions.aggfunctions.CountAccumulator>(); + | + | public $distinctCountFuncName() throws Exception { + | + | count = (org.apache.flink.table.functions.aggfunctions.CountAggFunction) + | ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject( + | "$distinctCountAggFunction", + | ${classOf[UserDefinedFunction].getCanonicalName}.class); + | + | 
distinctValueMap.setAccessible(true); + | } + | + | public void open(org.apache.flink.api.common.functions.RuntimeContext ctx) { + | org.apache.flink.api.common.state.StateDescriptor acc0_distinctValueMap_dataview_desc = + | (org.apache.flink.api.common.state.StateDescriptor) + | ${classOf[EncodingUtils].getCanonicalName}.decodeStringToObject( + | "$distinctCountDescriptor", + | ${classOf[StateDescriptor[_, _]].getCanonicalName}.class, + | ctx.getUserCodeClassLoader()); + | acc0_distinctValueMap_dataview = new org.apache.flink.table.dataview.StateMapView( + | ctx.getMapState((org.apache.flink.api.common.state.MapStateDescriptor) + | acc0_distinctValueMap_dataview_desc)); + | } + | + | public final void setAggregationResults( + | org.apache.flink.types.Row accs, + | org.apache.flink.types.Row output) { + | + | org.apache.flink.table.functions.AggregateFunction baseClass0 = + | (org.apache.flink.table.functions.AggregateFunction) + | count; + | + | org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 = + | (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) accs.getField(0); + | org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 = + | (org.apache.flink.table.functions.aggfunctions.CountAccumulator) + | distinctAcc0.getRealAcc(); + | + | output.setField(1, baseClass0.getValue(acc0)); + | } + | + | public final void accumulate( + | org.apache.flink.types.Row accs, + | org.apache.flink.types.Row input) throws Exception { + | + | org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 = + | (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) accs.getField(0); + | + | distinctValueMap.set(distinctAcc0, acc0_distinctValueMap_dataview); + | + | if (distinctAcc0.add( + | org.apache.flink.types.Row.of((java.lang.Integer) input.getField(1)))) { + | org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 = + | 
(org.apache.flink.table.functions.aggfunctions.CountAccumulator) + | distinctAcc0.getRealAcc(); + | + | + | count.accumulate(acc0, (java.lang.Integer) input.getField(1)); + | } + | } + | + | public final void retract( + | org.apache.flink.types.Row accs, + | org.apache.flink.types.Row input) throws Exception { + | + | org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 = + | (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) accs.getField(0); + | + | distinctValueMap.set(distinctAcc0, acc0_distinctValueMap_dataview); + | + | if (distinctAcc0.remove( + | org.apache.flink.types.Row.of((java.lang.Integer) input.getField(1)))) { + | org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 = + | (org.apache.flink.table.functions.aggfunctions.CountAccumulator) + | distinctAcc0.getRealAcc(); + | + | count.retract(acc0 , (java.lang.Integer) input.getField(1)); + | } + | } + | + | public final org.apache.flink.types.Row createAccumulators() + | { + | + | org.apache.flink.types.Row accs = new org.apache.flink.types.Row(1); + | + | org.apache.flink.table.functions.aggfunctions.CountAccumulator acc0 = + | (org.apache.flink.table.functions.aggfunctions.CountAccumulator) + | count.createAccumulator(); + | org.apache.flink.table.functions.aggfunctions.DistinctAccumulator distinctAcc0 = + | (org.apache.flink.table.functions.aggfunctions.DistinctAccumulator) + | new org.apache.flink.table.functions.aggfunctions.DistinctAccumulator (acc0); + | accs.setField( + | 0, + | distinctAcc0); + | + | return accs; + | } + | + | public final void setForwardedFields( + | org.apache.flink.types.Row input, + | org.apache.flink.types.Row output) + | { + | + | output.setField( + | 0, + | input.getField(0)); + | } + | | public final void setConstantFlags(org.apache.flink.types.Row output) | { | @@ -304,10 +476,8 @@ class HarnessTestBase { | org.apache.flink.types.Row accs) { | } | - | public void 
open(org.apache.flink.api.common.functions.RuntimeContext ctx) { - | } - | | public void cleanup() { + | acc0_distinctValueMap_dataview.clear(); | } | | public void close() { @@ -315,12 +485,11 @@ class HarnessTestBase { |} |""".stripMargin - - protected val minMaxFuncName = "MinMaxAggregateHelper" - protected val sumFuncName = "SumAggregationHelper" - protected val genMinMaxAggFunction = GeneratedAggregationsFunction(minMaxFuncName, minMaxCode) protected val genSumAggFunction = GeneratedAggregationsFunction(sumFuncName, sumAggCode) + protected val genDistinctCountAggFunction = GeneratedAggregationsFunction( + distinctCountFuncName, + distinctCountAggCode) def createHarnessTester[IN, OUT, KEY]( operator: OneInputStreamOperator[IN, OUT], diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala index c3da65f887a..46dde8e0225 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala @@ -263,6 +263,43 @@ class SqlITCase extends StreamingWithStateTestBase { assertEquals(expected.sorted, StreamITCase.retractedResults.sorted) } + @Test + def testDistinctWithRetraction(): Unit = { + + val env = StreamExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env) + StreamITCase.clear + + val data = new mutable.MutableList[(Int, Long, String)] + data.+=((1, 1L, "Hi")) + data.+=((1, 1L, "Hi World")) + data.+=((1, 1L, "Test")) + data.+=((2, 1L, "Hi World")) + data.+=((2, 1L, "Test")) + data.+=((3, 1L, "Hi World")) + data.+=((3, 1L, "Hi World")) + data.+=((3, 1L, "Hi World")) + data.+=((4, 1L, "Hi World")) + data.+=((4, 1L, "Test")) + + val t = env.fromCollection(data).toTable(tEnv).as('a, 'b, 'c) + 
tEnv.registerTable("MyTable", t) + + // "1,1,3", "2,1,2", "3,1,1", "4,1,2" + val distinct = "SELECT a, COUNT(DISTINCT b) AS distinct_b, COUNT(DISTINCT c) AS distinct_c " + + "FROM MyTable GROUP BY a" + val nestedDistinct = s"SELECT distinct_b, COUNT(DISTINCT distinct_c) " + + s"FROM ($distinct) GROUP BY distinct_b" + + val result = tEnv.sqlQuery(nestedDistinct).toRetractStream[Row] + result.addSink(new StreamITCase.RetractingSink).setParallelism(1) + + env.execute() + + val expected = List("1,3") + assertEquals(expected.sorted, StreamITCase.retractedResults.sorted) + } + @Test def testUnboundedGroupByCollect(): Unit = { ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org > DistinctAccumulator.remove lead to NPE > -------------------------------------- > > Key: FLINK-10674 > URL: https://issues.apache.org/jira/browse/FLINK-10674 > Project: Flink > Issue Type: Bug > Components: Table API & SQL > Affects Versions: 1.6.1 > Environment: Flink 1.6.0 > Reporter: ambition > Assignee: winifredtang > Priority: Minor > Labels: pull-request-available > Attachments: image-2018-10-25-14-46-03-373.png > > > Our online Flink job had been running for about a week; the job contains this SQL: > {code:java} > select `time`, > lower(trim(os_type)) as os_type, > count(distinct feed_id) as feed_total_view > from my_table > group by `time`, lower(trim(os_type)){code} > > Then an NPE occurred: > > {code:java} > java.lang.NullPointerException > at scala.Predef$.Long2long(Predef.scala:363) > at > org.apache.flink.table.functions.aggfunctions.DistinctAccumulator.remove(DistinctAccumulator.scala:109) > at NonWindowedAggregationHelper$894.retract(Unknown Source) > at > 
org.apache.flink.table.runtime.aggregate.GroupAggProcessFunction.processElement(GroupAggProcessFunction.scala:124) > at > org.apache.flink.table.runtime.aggregate.GroupAggProcessFunction.processElement(GroupAggProcessFunction.scala:39) > at > org.apache.flink.streaming.api.operators.LegacyKeyedProcessOperator.processElement(LegacyKeyedProcessOperator.java:88) > at > org.apache.flink.streaming.runtime.io.StreamInputProcessor.processInput(StreamInputProcessor.java:202) > at > org.apache.flink.streaming.runtime.tasks.OneInputStreamTask.run(OneInputStreamTask.java:105) > at > org.apache.flink.streaming.runtime.tasks.StreamTask.invoke(StreamTask.java:300) > at org.apache.flink.runtime.taskmanager.Task.run(Task.java:711) > at java.lang.Thread.run(Thread.java:745) > {code} > > > See DistinctAccumulator.remove: > !image-2018-10-25-14-46-03-373.png! > > This NPE is caused by currentCnt being null, so we simply handle it like this: > {code:java} > def remove(params: Row): Boolean = { > if(!distinctValueMap.contains(params)){ > true > }else{ > val currentCnt = distinctValueMap.get(params) > // > if (currentCnt == null || currentCnt == 1) { > distinctValueMap.remove(params) > true > } else { > var value = currentCnt - 1L > if(value < 0){ > value = 1 > } > distinctValueMap.put(params, value) > false > } > } > }{code} > > -- This message was sent by Atlassian JIRA (v7.6.3#76005)