[ https://issues.apache.org/jira/browse/FLINK-6250?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=15978902#comment-15978902 ]
ASF GitHub Bot commented on FLINK-6250: --------------------------------------- Github user fhueske commented on a diff in the pull request: https://github.com/apache/flink/pull/3732#discussion_r112708098 --- Diff: flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/ProcTimeBoundedDistinctRowsOver.scala --- @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.flink.table.runtime.aggregate + +import java.util + +import org.apache.flink.configuration.Configuration +import org.apache.flink.streaming.api.functions.ProcessFunction +import org.apache.flink.types.Row +import org.apache.flink.util.{Collector, Preconditions} +import org.apache.flink.api.common.state.ValueStateDescriptor +import org.apache.flink.api.java.typeutils.RowTypeInfo +import org.apache.flink.api.common.state.ValueState +import org.apache.flink.table.functions.{Accumulator, AggregateFunction} +import org.apache.flink.api.common.state.MapState +import org.apache.flink.api.common.state.MapStateDescriptor +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.ListTypeInfo +import java.util.{List => JList} + +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.api.common.state.ListState + +class ProcTimeBoundedDistinctRowsOver( + private val aggregates: Array[AggregateFunction[_]], + private val aggFields: Array[Array[Int]], + private val distinctAggsFlag: Array[Boolean], + private val precedingOffset: Long, + private val forwardedFieldCount: Int, + private val aggregatesTypeInfo: RowTypeInfo, + private val inputType: TypeInformation[Row]) + extends ProcessFunction[Row, Row] { + + Preconditions.checkNotNull(aggregates) + Preconditions.checkNotNull(aggFields) + Preconditions.checkNotNull(distinctAggsFlag) + Preconditions.checkNotNull(distinctAggsFlag.length == aggregates.length) + Preconditions.checkArgument(aggregates.length == aggFields.length) + Preconditions.checkArgument(precedingOffset > 0) + + private var accumulatorState: ValueState[Row] = _ + private var rowMapState: MapState[Long, JList[Row]] = _ + private var output: Row = _ + private var counterState: ValueState[Long] = _ + private var smallestTsState: ValueState[Long] = _ + private var distinctValueStateList: Array[MapState[Any, Long]] = _ + + override def open(config: Configuration) { + + 
output = new Row(forwardedFieldCount + aggregates.length) + // We keep the elements received in a Map state keyed + // by the ingestion time in the operator. + // we also keep counter of processed elements + // and timestamp of oldest element + val rowListTypeInfo: TypeInformation[JList[Row]] = + new ListTypeInfo[Row](inputType).asInstanceOf[TypeInformation[JList[Row]]] + + val mapStateDescriptor: MapStateDescriptor[Long, JList[Row]] = + new MapStateDescriptor[Long, JList[Row]]("windowBufferMapState", + BasicTypeInfo.LONG_TYPE_INFO.asInstanceOf[TypeInformation[Long]], rowListTypeInfo) + rowMapState = getRuntimeContext.getMapState(mapStateDescriptor) + + val aggregationStateDescriptor: ValueStateDescriptor[Row] = + new ValueStateDescriptor[Row]("aggregationState", aggregatesTypeInfo) + accumulatorState = getRuntimeContext.getState(aggregationStateDescriptor) + + val processedCountDescriptor : ValueStateDescriptor[Long] = + new ValueStateDescriptor[Long]("processedCountState", classOf[Long]) + counterState = getRuntimeContext.getState(processedCountDescriptor) + + val smallestTimestampDescriptor : ValueStateDescriptor[Long] = + new ValueStateDescriptor[Long]("smallestTSState", classOf[Long]) + smallestTsState = getRuntimeContext.getState(smallestTimestampDescriptor) + distinctValueStateList = new Array(aggregates.size) + for(i <- 0 until aggregates.size){ + if(distinctAggsFlag(i)){ + val distinctValDescriptor = new MapStateDescriptor[Any, Long]( + "distinctValuesBufferMapState" + i, + classOf[Any], + classOf[Long]) + distinctValueStateList(i)=getRuntimeContext.getMapState(distinctValDescriptor) + } + } + } + + override def processElement( + input: Row, + ctx: ProcessFunction[Row, Row]#Context, + out: Collector[Row]): Unit = { + + val currentTime = ctx.timerService.currentProcessingTime + var i = 0 + + // initialize state for the processed element + var accumulators = accumulatorState.value + if (accumulators == null) { + accumulators = new Row(aggregates.length) + 
while (i < aggregates.length) { + accumulators.setField(i, aggregates(i).createAccumulator()) + i += 1 + } + } + + // get smallest timestamp + var smallestTs = smallestTsState.value + if (smallestTs == 0L) { + smallestTs = currentTime + smallestTsState.update(smallestTs) + } + // get previous counter value + var counter = counterState.value + + if (counter == precedingOffset) { + val retractList = rowMapState.get(smallestTs) + + // get oldest element beyond buffer size + // and if oldest element exist, retract value + var removeCounter :Integer = 0 + var distinctCounter : Integer = 0 + var retractVal : Object = null + i = 0 + while (i < aggregates.length) { + val accumulator = accumulators.getField(i).asInstanceOf[Accumulator] + retractVal = retractList.get(0).getField(aggFields(i)(0)) + if(distinctAggsFlag(i)){ + var distinctValCounter: Long = distinctValueStateList(i).get(retractVal) + // if the value to be retract is the last one added + // the remove it and retract the value + if(distinctValCounter == 1L){ + aggregates(i).retract(accumulator, retractVal) + distinctValueStateList(i).remove(retractVal) + } // if the are other values in the buffer + // decrease the counter and continue + else { + distinctValCounter -= 1 + distinctValueStateList(i).put(retractVal,distinctValCounter) + } + }else { + aggregates(i).retract(accumulator, retractVal) + } + i += 1 + } + retractList.remove(0) + // if reference timestamp list not empty, keep the list + if (!retractList.isEmpty) { + rowMapState.put(smallestTs, retractList) + } // if smallest timestamp list is empty, remove and find new smallest + else { + rowMapState.remove(smallestTs) + val iter = rowMapState.keys.iterator + var currentTs: Long = 0L + var newSmallestTs: Long = Long.MaxValue + while (iter.hasNext) { + currentTs = iter.next + if (currentTs < newSmallestTs) { + newSmallestTs = currentTs + } + } + smallestTsState.update(newSmallestTs) + } + } // we update the counter only while buffer is getting filled + else { 
+ counter += 1 + counterState.update(counter) + } + + // copy forwarded fields in output row + i = 0 + while (i < forwardedFieldCount) { + output.setField(i, input.getField(i)) + i += 1 + } + + // accumulate current row and set aggregate in output row + i = 0 + while (i < aggregates.length) { + val index = forwardedFieldCount + i + val accumulator = accumulators.getField(i).asInstanceOf[Accumulator] + val inputValue = input.getField(aggFields(i)(0)) + // check if distinct aggregation + if(distinctAggsFlag(i)){ + // if first time we see value, set counter and aggregate + var distinctValCounter: Long = distinctValueStateList(i).get(inputValue) + // if counter is 0L first time we aggregate + // for a seen value but never accumulated + if(distinctValCounter == 0L){ --- End diff -- doesn't `MapState.get` return `null` when the key is not contained? > Distinct procTime with Rows boundaries > -------------------------------------- > > Key: FLINK-6250 > URL: https://issues.apache.org/jira/browse/FLINK-6250 > Project: Flink > Issue Type: Sub-task > Components: Table API & SQL > Reporter: radu > Assignee: Stefano Bortoli > > Support proctime with rows boundaries > Q1.1. `SELECT SUM( DISTINCT b) OVER (ORDER BY procTime() ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) FROM stream1` > Q1.1. `SELECT COUNT(b), SUM( DISTINCT b) OVER (ORDER BY procTime() ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) FROM stream1` -- This message was sent by Atlassian JIRA (v6.3.15#6346)