viirya commented on a change in pull request #33081:
URL: https://github.com/apache/spark/pull/33081#discussion_r670139067
##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
##########
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.util.{DateTimeConstants, IntervalUtils}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends UnaryExpression

Review comment:
   Could you add a few simple comments here? e.g., what `gapDuration` stands for.

##########
File path: python/pyspark/sql/functions.py
##########
@@ -2325,6 +2325,41 @@ def check_string_field(field, fieldName):
     return Column(res)

+def session_window(timeColumn, gapDuration):
+    """
+    Generates session window given a timestamp specifying column.
+    Session window is the one of dynamic windows, which means the length of window is vary

Review comment:
   is vary -> is varying?
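For context, a minimal sketch of how the `session_window` API added in this PR is meant to be used, and what the gap duration stands for in practice. The dataset, column names, and the 5-minute gap below are illustrative assumptions, not part of the reviewed diff.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, session_window}

object SessionWindowGapExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("session-window-gap").getOrCreate()
    import spark.implicits._

    // Hypothetical event data: two events within 5 minutes belong to one session,
    // the third event arrives after a longer silence and starts a new session.
    val events = Seq(
      ("2021-07-15 10:00:00", "userA"),
      ("2021-07-15 10:03:00", "userA"),
      ("2021-07-15 10:20:00", "userA")
    ).toDF("time", "user")
      .selectExpr("CAST(time AS TIMESTAMP) AS time", "user")

    // The gap duration is the inactivity gap: a session keeps extending while events
    // arrive within the gap, and closes once the gap elapses with no new event.
    events
      .groupBy(session_window($"time", "5 minutes"), $"user")
      .agg(count("*").as("numEvents"))
      .show(truncate = false)

    spark.stop()
  }
}
```

With this data the first two rows collapse into a single session whose end is the last event time plus the gap (roughly 10:00:00 to 10:08:00), while the third row forms its own session.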
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
##########
@@ -345,4 +351,179 @@ object AggUtils {
     finalAndCompleteAggregate :: Nil
   }
+
+  /**
+   * Plans a streaming session aggregation using the following progression:
+   *
+   *  - Partial Aggregation
+   *    - all tuples will have aggregated columns with initial value
+   *  - (If "spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition" is enabled)
+   *    - Sort within partition (sort: all keys)
+   *    - MergingSessionExec
+   *      - calculate session among tuples, and aggregate tuples in session with partial merge
+   *  - Shuffle & Sort (distribution: keys "without" session, sort: all keys)
+   *  - SessionWindowStateStoreRestore (group: keys "without" session)
+   *    - merge input tuples with stored tuples (sessions) respecting sort order
+   *  - MergingSessionExec
+   *    - calculate session among tuples, and aggregate tuples in session with partial merge
+   *    - NOTE: it leverages the fact that the output of SessionWindowStateStoreRestore is sorted
+   *    - now there is at most 1 tuple per group, key with session
+   *  - SessionWindowStateStoreSave (group: keys "without" session)
+   *    - saves tuple(s) for the next batch (multiple sessions could co-exist at the same time)
+   *  - Complete (output the current result of the aggregation)
+   */
+  def planStreamingAggregationForSession(
+      groupingExpressions: Seq[NamedExpression],
+      sessionExpression: NamedExpression,
+      functionsWithoutDistinct: Seq[AggregateExpression],
+      resultExpressions: Seq[NamedExpression],
+      stateFormatVersion: Int,
+      mergeSessionsInLocalPartition: Boolean,
+      child: SparkPlan): Seq[SparkPlan] = {
+
+    val groupWithoutSessionExpression = groupingExpressions.filterNot { p =>
+      p.semanticEquals(sessionExpression)
+    }
+
+    if (groupWithoutSessionExpression.isEmpty) {
+      throw new AnalysisException("Global aggregation with session window in streaming query" +
+        " is not supported.")
+    }
+
+    val groupingWithoutSessionAttributes = groupWithoutSessionExpression.map(_.toAttribute)
+
+    val groupingAttributes = groupingExpressions.map(_.toAttribute)
+
+    // we don't do partial aggregate here, because it requires additional shuffle

Review comment:
   "we don't do partial aggregate here"? I think the following does partial aggregate, no?

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
##########
@@ -1610,6 +1610,26 @@ object SQLConf {
     .checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
     .createWithDefault(2)
+
+  val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION =
+    buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition")

Review comment:
   Yea, is it necessary to have this config? Seems we can always do it.

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
##########
@@ -113,6 +114,9 @@ object AggUtils {
         resultExpressions = partialResultExpressions,
         child = child)
+
+    val interExec: SparkPlan = mayAppendMergingSessionExec(groupingExpressions,

Review comment:
   Could you add a comment here explaining why we need to append `MergingSessionExec`?
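To make the planning progression above concrete, here is a hedged sketch of the kind of streaming query this code path is meant to plan. The rate source, the derived `user` key, the 10-second gap, and the output mode are illustrative assumptions, not taken from the PR.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, session_window}

object StreamingSessionAggregationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("streaming-session-agg").getOrCreate()
    import spark.implicits._

    // The built-in rate source emits (timestamp, value) rows; derive a small key space
    // so that grouping keys "without" the session column exist -- a global session
    // aggregation is rejected by planStreamingAggregationForSession.
    val events = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "5")
      .load()
      .withColumn("user", $"value" % 3)

    // Group by both the session window and a regular key.
    val sessions = events
      .groupBy(session_window($"timestamp", "10 seconds"), $"user")
      .agg(count("*").as("numEvents"))

    val query = sessions.writeStream
      .outputMode("update")
      .format("console")
      .start()

    query.awaitTermination()
  }
}
```

Grouping by both the session column and `user` is what makes the planner separate the grouping keys "without" session from the full grouping keys: per the documented progression, the shuffle and the state store are keyed on the non-session keys, while sessions for the same key are merged relying on the sort order over all keys.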
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
##########
@@ -140,6 +144,8 @@ object AggUtils {
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
+
+    val maySessionChild = mayAppendUpdatingSessionExec(groupingExpressions, child)

Review comment:
   ditto

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/UpdatingSessionsIterator.scala
##########
@@ -190,19 +191,21 @@ class UpdatingSessionsIterator(
   }

   private def closeCurrentSession(keyChanged: Boolean): Unit = {
-    assert(returnRowsIter == null || !returnRowsIter.hasNext)
-
     returnRows = rowsForCurrentSession
     rowsForCurrentSession = null

-    val groupingKey = generateGroupingKey()
+    val groupingKey = generateGroupingKey().copy()

     val currentRowsIter = returnRows.generateIterator().map { internalRow =>
       val valueRow = valueProj(internalRow)
       restoreProj(join2(groupingKey, valueRow)).copy()
     }

-    returnRowsIter = currentRowsIter
+    if (returnRowsIter != null && returnRowsIter.hasNext) {
+      returnRowsIter = returnRowsIter ++ currentRowsIter
+    } else {
+      returnRowsIter = currentRowsIter
+    }

Review comment:
   why add this change?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
##########
@@ -107,13 +116,14 @@ case class AggregateInPandasExec(
     // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty
     inputRDD.mapPartitionsInternal { iter =>
       if (iter.isEmpty) iter else {
+        val newIter: Iterator[InternalRow] = mayAppendUpdatingSessionIterator(iter)

Review comment:
   Could you also add a comment here?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/functions.scala
##########
@@ -3630,6 +3630,35 @@ object functions {
     window(timeColumn, windowDuration, windowDuration, "0 second")
   }

+  /**
+   * Generates session window given a timestamp specifying column.
+   *
+   * Session window is the one of dynamic windows, which means the length of window is vary

Review comment:
   Session window is the one of dynamic windows -> Session window is one of dynamic windows

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -3933,6 +3938,83 @@ object TimeWindowing extends Rule[LogicalPlan] {
     }
   }

+/** Maps a time column to a session window. */
+object SessionWindowing extends Rule[LogicalPlan] {
+  import org.apache.spark.sql.catalyst.dsl.expressions._
+
+  private final val SESSION_COL_NAME = "session_window"
+  private final val SESSION_START = "start"
+  private final val SESSION_END = "end"
+
+  /**
+   * Generates the logical plan for generating session window on a timestamp column.
+   * Each session window is initially defined as [timestamp, timestamp + gap).
+   *
+   * This also adds a marker to the session column so that downstream can easily find the column
+   * on session window.
+   */
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+    case p: LogicalPlan if p.children.size == 1 =>
+      val child = p.children.head
+      val sessionExpressions =
+        p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet
+
+      val numWindowExpr = p.expressions.flatMap(_.collect {
+        case s: SessionWindow => s
+        case t: TimeWindow => t
+      }).toSet.size
+
+      // Only support a single session expression for now
+      if (numWindowExpr == 1 && sessionExpressions.nonEmpty &&
+          sessionExpressions.head.timeColumn.resolved &&
+          sessionExpressions.head.checkInputDataTypes().isSuccess) {
+
+        val session = sessionExpressions.head
+
+        val metadata = session.timeColumn match {
+          case a: Attribute => a.metadata
+          case _ => Metadata.empty
+        }
+
+        val newMetadata = new MetadataBuilder()
+          .withMetadata(metadata)
+          .putBoolean(SessionWindow.marker, true)
+          .build()
+
+        val sessionAttr = AttributeReference(
+          SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
+
+        val sessionStart = PreciseTimestampConversion(session.timeColumn, TimestampType, LongType)
+        val sessionEnd = sessionStart + session.gapDuration
+
+        val literalSessionStruct = CreateNamedStruct(
+          Literal(SESSION_START) ::
+            PreciseTimestampConversion(sessionStart, LongType, TimestampType) ::
+            Literal(SESSION_END) ::
+            PreciseTimestampConversion(sessionEnd, LongType, TimestampType) ::
+            Nil)
+
+        val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
+          exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
+
+        val replacedPlan = p transformExpressions {
+          case s: SessionWindow => sessionAttr
+        }
+
+        // For backwards compatibility we add a filter to filter out nulls
+        val filterExpr = IsNotNull(session.timeColumn)
+
+        replacedPlan.withNewChildren(
+          Filter(filterExpr,
+            Project(sessionStruct +: child.output, child)) :: Nil)

Review comment:
   Hmm, why do we filter out the null time column after evaluating the session? Shouldn't we filter before evaluating the session?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/functions.scala
##########
@@ -3630,6 +3630,35 @@ object functions {
     window(timeColumn, windowDuration, windowDuration, "0 second")
   }

+  /**
+   * Generates session window given a timestamp specifying column.
+   *
+   * Session window is the one of dynamic windows, which means the length of window is vary

Review comment:
   the length of window is vary -> the length of window is varying

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/functions.scala
##########
@@ -3630,6 +3630,35 @@ object functions {
     window(timeColumn, windowDuration, windowDuration, "0 second")
   }

+  /**
+   * Generates session window given a timestamp specifying column.
+   *
+   * Session window is the one of dynamic windows, which means the length of window is vary
+   * according to the given inputs. The length of session window is defined as "the timestamp
+   * of latest input of the session + gap duration", so when the new inputs are bound to the
+   * current session window, the end time of session window can be expanded according to the new
+   * inputs.
+   *
+   * Windows can support microsecond precision. Windows in the order of months are not supported.

Review comment:
   > "Windows in the order of months are not supported."
   
   I don't get what this sentence means.
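As a side note on the `[timestamp, timestamp + gap)` definition used by the rule above, here is a hedged, plain-Scala illustration (not Spark code) of the per-row assignment and of how overlapping windows are later collapsed into sessions. Timestamps are microseconds, mirroring the `PreciseTimestampConversion` to `LongType`; the merging itself is done by the session-merging execution operators, not by this analyzer rule.

```scala
object SessionAssignmentSketch {
  // A session window is the half-open interval [start, end).
  final case class Session(start: Long, end: Long)

  // Each row initially gets [timestamp, timestamp + gap); rows whose windows overlap
  // are then collapsed into a single session whose end is the latest window end.
  def mergeSessions(timestampsUs: Seq[Long], gapUs: Long): Seq[Session] = {
    val initial = timestampsUs.sorted.map(ts => Session(ts, ts + gapUs))
    initial.foldLeft(Vector.empty[Session]) { (acc, cur) =>
      acc.lastOption match {
        case Some(last) if cur.start < last.end =>
          // Overlaps the previous session: extend its end instead of opening a new one.
          acc.init :+ Session(last.start, math.max(last.end, cur.end))
        case _ =>
          acc :+ cur
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val minuteUs = 60L * 1000 * 1000
    val gapUs = 5 * minuteUs
    // Events at 0, 2 and 20 minutes with a 5-minute gap => two sessions:
    // [0 min, 7 min) and [20 min, 25 min).
    println(mergeSessions(Seq(0L, 2 * minuteUs, 20 * minuteUs), gapUs))
  }
}
```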
##########
File path: python/pyspark/sql/functions.py
##########
@@ -2325,6 +2325,41 @@ def check_string_field(field, fieldName):
     return Column(res)

+def session_window(timeColumn, gapDuration):
+    """
+    Generates session window given a timestamp specifying column.
+    Session window is the one of dynamic windows, which means the length of window is vary

Review comment:
   Session window is the one of dynamic windows -> Session window is one of dynamic windows

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
##########
@@ -53,12 +54,23 @@ case class AggregateInPandasExec(

   override def producedAttributes: AttributeSet = AttributeSet(output)

-  override def requiredChildDistribution: Seq[Distribution] = {
-    if (groupingExpressions.isEmpty) {
-      AllTuples :: Nil
-    } else {
-      ClusteredDistribution(groupingExpressions) :: Nil
-    }
+  val sessionWindowOption = groupingExpressions.find { p =>
+    p.metadata.contains(SessionWindow.marker)
+  }
+
+  val groupingWithoutSessionExpressions = sessionWindowOption match {
+    case Some(sessionExpression) =>
+      groupingExpressions.filterNot { p => p.semanticEquals(sessionExpression) }
+
+    case None => groupingExpressions
+  }
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = sessionWindowOption match {
+    case Some(sessionExpression) =>
+      Seq((groupingWithoutSessionExpressions ++ Seq(sessionExpression))
+        .map(SortOrder(_, Ascending)))
+
+    case None => Seq(groupingExpressions.map(SortOrder(_, Ascending)))

Review comment:
   Hmm, why do we change the original `requiredChildDistribution`?

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
