xuanyuanking commented on a change in pull request #33081:
URL: https://github.com/apache/spark/pull/33081#discussion_r670116061
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
##########
@@ -3933,6 +3938,83 @@ object TimeWindowing extends Rule[LogicalPlan] {
}
}
+/** Maps a time column to a session window. */
+object SessionWindowing extends Rule[LogicalPlan] {
+ import org.apache.spark.sql.catalyst.dsl.expressions._
+
+ private final val SESSION_COL_NAME = "session_window"
+ private final val SESSION_START = "start"
+ private final val SESSION_END = "end"
+
+ /**
+ * Generates the logical plan for generating session window on a timestamp
column.
+ * Each session window is initially defined as [timestamp, timestamp + gap).
+ *
+ * This also adds a marker to the session column so that downstream can
easily find the column
+ * on session window.
+ */
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+ case p: LogicalPlan if p.children.size == 1 =>
+ val child = p.children.head
+ val sessionExpressions =
+ p.expressions.flatMap(_.collect { case s: SessionWindow => s }).toSet
+
+ val numWindowExpr = p.expressions.flatMap(_.collect {
+ case s: SessionWindow => s
+ case t: TimeWindow => t
+ }).toSet.size
+
+ // Only support a single session expression for now
+ if (numWindowExpr == 1 && sessionExpressions.nonEmpty &&
+ sessionExpressions.head.timeColumn.resolved &&
+ sessionExpressions.head.checkInputDataTypes().isSuccess) {
+
+ val session = sessionExpressions.head
+
+ val metadata = session.timeColumn match {
+ case a: Attribute => a.metadata
+ case _ => Metadata.empty
+ }
+
+ val newMetadata = new MetadataBuilder()
+ .withMetadata(metadata)
+ .putBoolean(SessionWindow.marker, true)
+ .build()
+
+ val sessionAttr = AttributeReference(
+ SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
+
+ val sessionStart = PreciseTimestampConversion(session.timeColumn,
TimestampType, LongType)
+ val sessionEnd = sessionStart + session.gapDuration
+
+ val literalSessionStruct = CreateNamedStruct(
+ Literal(SESSION_START) ::
+ PreciseTimestampConversion(sessionStart, LongType, TimestampType)
::
+ Literal(SESSION_END) ::
+ PreciseTimestampConversion(sessionEnd, LongType, TimestampType) ::
+ Nil)
+
+ val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
+ exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
+
+ val replacedPlan = p transformExpressions {
+ case s: SessionWindow => sessionAttr
+ }
+
+ // For backwards compatibility we add a filter to filter out nulls
Review comment:
I didn't fully get the point of backward compatibility here. Could you
explain in more detail?
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
##########
@@ -1610,6 +1610,26 @@ object SQLConf {
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)
+ val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION =
+
buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition")
Review comment:
How about `spark.sql.streaming.sessionWindow.localMerge.enabled` or
`spark.sql.streaming.sessionWindow.mergeSessionsInLocalPartition.enabled`?
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
##########
@@ -1610,6 +1610,26 @@ object SQLConf {
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)
+ val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION =
+
buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition")
Review comment:
Compared with the similar logic in AggUtils, maybe we can remove this
config and just always do the local merge?
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
##########
@@ -1610,6 +1610,26 @@ object SQLConf {
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)
+ val STREAMING_SESSION_WINDOW_MERGE_SESSIONS_IN_LOCAL_PARTITION =
+
buildConf("spark.sql.streaming.sessionWindow.merge.sessions.in.local.partition")
+ .internal()
+ .doc("When true, streaming session window sorts and merge sessions in
local partition " +
+ "prior to shuffle. This is to reduce the rows to shuffle, but only
beneficial when " +
+ "there're lots of rows in a batch being assigned to same sessions.")
+ .booleanConf
Review comment:
nit: this config is missing `.version(...)`.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
##########
@@ -335,12 +339,29 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
}
}
- AggUtils.planStreamingAggregation(
- normalizedGroupingExpressions,
- aggregateExpressions.map(expr =>
expr.asInstanceOf[AggregateExpression]),
- rewrittenResultExpressions,
- stateVersion,
- planLater(child))
+ sessionWindowOption match {
+ case Some(sessionWindow) =>
+ val stateVersion =
conf.getConf(SQLConf.STREAMING_SESSION_WINDOW_STATE_FORMAT_VERSION)
+
+ AggUtils.planStreamingAggregationForSession(
+ normalizedGroupingExpressions,
+ sessionWindow,
+ aggregateExpressions.map(expr =>
expr.asInstanceOf[AggregateExpression]),
+ rewrittenResultExpressions,
+ stateVersion,
+ conf.streamingSessionWindowMergeSessionInLocalPartition,
+ planLater(child))
+
+ case None =>
+ val stateVersion =
conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
Review comment:
nit: we now read the aggregation state format version twice; see:
https://github.com/apache/spark/pull/33081/files#diff-21f071d73070b8257ad76e6e16ec5ed38a13d1278fe94bd42546c258a69f4410R327
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
##########
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.util.{DateTimeConstants, IntervalUtils}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends
UnaryExpression
+ with ImplicitCastInputTypes
+ with Unevaluable
+ with NonSQLExpression {
+
+ //////////////////////////
+ // SQL Constructors
+ //////////////////////////
+
+ def this(timeColumn: Expression, gapDuration: Expression) = {
+ this(timeColumn, SessionWindow.parseExpression(gapDuration))
+ }
+
+ override def child: Expression = timeColumn
+ override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
+ override def dataType: DataType = new StructType()
+ .add(StructField("start", TimestampType))
+ .add(StructField("end", TimestampType))
+
+ // This expression is replaced in the analyzer.
+ override lazy val resolved = false
+
+ /** Validate the inputs for the gap duration in addition to the input data
type. */
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val dataTypeCheck = super.checkInputDataTypes()
+ if (dataTypeCheck.isSuccess) {
+ if (gapDuration <= 0) {
+ return TypeCheckFailure(s"The window duration ($gapDuration) must be
greater than 0.")
+ }
+ }
+ dataTypeCheck
+ }
+
+ override protected def withNewChildInternal(newChild: Expression):
Expression =
+ copy(timeColumn = newChild)
+}
+
+object SessionWindow {
+ val marker = "spark.sessionWindow"
+
+ /**
+ * Parses the interval string for a valid time duration. CalendarInterval
expects interval
+ * strings to start with the string `interval`. For usability, we prepend
`interval` to the string
+ * if the user omitted it.
+ *
+ * @param interval The interval string
+ * @return The interval duration in microseconds. SparkSQL casts
TimestampType has microsecond
+ * precision.
+ */
+ private def getIntervalInMicroSeconds(interval: String): Long = {
Review comment:
It seems we can directly use `TimeWindow.getIntervalInMicroSeconds` here?
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
##########
@@ -345,4 +351,179 @@ object AggUtils {
finalAndCompleteAggregate :: Nil
}
+
+ /**
+ * Plans a streaming session aggregation using the following progression:
+ *
+ * - Partial Aggregation
+ * - all tuples will have aggregated columns with initial value
Review comment:
nit: maybe we can mention `UpdateSessionsExec` here.
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SessionWindow.scala
##########
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.util.{DateTimeConstants, IntervalUtils}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+case class SessionWindow(timeColumn: Expression, gapDuration: Long) extends
UnaryExpression
+ with ImplicitCastInputTypes
+ with Unevaluable
+ with NonSQLExpression {
+
+ //////////////////////////
+ // SQL Constructors
+ //////////////////////////
+
+ def this(timeColumn: Expression, gapDuration: Expression) = {
+ this(timeColumn, SessionWindow.parseExpression(gapDuration))
+ }
+
+ override def child: Expression = timeColumn
+ override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
+ override def dataType: DataType = new StructType()
+ .add(StructField("start", TimestampType))
+ .add(StructField("end", TimestampType))
+
+ // This expression is replaced in the analyzer.
+ override lazy val resolved = false
+
+ /** Validate the inputs for the gap duration in addition to the input data
type. */
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val dataTypeCheck = super.checkInputDataTypes()
+ if (dataTypeCheck.isSuccess) {
+ if (gapDuration <= 0) {
+ return TypeCheckFailure(s"The window duration ($gapDuration) must be
greater than 0.")
+ }
+ }
+ dataTypeCheck
+ }
+
+ override protected def withNewChildInternal(newChild: Expression):
Expression =
+ copy(timeColumn = newChild)
+}
+
+object SessionWindow {
+ val marker = "spark.sessionWindow"
+
+ /**
+ * Parses the interval string for a valid time duration. CalendarInterval
expects interval
+ * strings to start with the string `interval`. For usability, we prepend
`interval` to the string
+ * if the user omitted it.
+ *
+ * @param interval The interval string
+ * @return The interval duration in microseconds. SparkSQL casts
TimestampType has microsecond
+ * precision.
+ */
+ private def getIntervalInMicroSeconds(interval: String): Long = {
+ val cal = IntervalUtils.stringToInterval(UTF8String.fromString(interval))
+ if (cal.months != 0) {
+ throw new IllegalArgumentException(
+ s"Intervals greater than a month is not supported ($interval).")
+ }
+ cal.days * DateTimeConstants.MICROS_PER_DAY + cal.microseconds
+ }
+
+ /**
+ * Parses the duration expression to generate the long value for the
original constructor so
+ * that we can use `window` in SQL.
+ */
+ private def parseExpression(expr: Expression): Long = expr match {
Review comment:
Ditto: can we directly use `TimeWindow.parseExpression` here?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]