davidm-db commented on code in PR #48794:
URL: https://github.com/apache/spark/pull/48794#discussion_r1850890377
##########
sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala:
##########
@@ -636,3 +649,223 @@ class LoopStatementExec(
body.reset()
}
}
+
+/**
+ * Executable node for ForStatement.
+ * @param query Executable node for the query.
+ * @param variableName Name of variable used for accessing current row during iteration.
+ * @param body Executable node for the body.
+ * @param label Label set to ForStatement by user or None otherwise.
+ * @param session Spark session that SQL script is executed within.
+ */
+class ForStatementExec(
+ query: SingleStatementExec,
+ variableName: Option[String],
+ body: CompoundBodyExec,
+ val label: Option[String],
+ session: SparkSession) extends NonLeafStatementExec {
+
+ private object ForState extends Enumeration {
+ val VariableAssignment, Body, VariableCleanup = Value
+ }
+ private var state = ForState.VariableAssignment
+ private var currRow = 0
+ private var areVariablesDeclared = false
+
+ // map of all variables created internally by the for statement
+ // (variableName -> variableExpression)
+ private var variablesMap: Map[String, Expression] = Map()
+
+ // compound body used for dropping variables while in ForState.VariableAssignment
+ private var dropVariablesExec: CompoundBodyExec = null
+
+ private var queryResult: Array[Row] = null
+ private var isResultCacheValid = false
+ private def cachedQueryResult(): Array[Row] = {
+ if (!isResultCacheValid) {
+ queryResult = query.buildDataFrame(session).collect()
+ isResultCacheValid = true
+ }
+ queryResult
+ }
+
+ /**
+ * For can be interrupted by LeaveStatementExec
+ */
+ private var interrupted: Boolean = false
+
+ private lazy val treeIterator: Iterator[CompoundStatementExec] =
+ new Iterator[CompoundStatementExec] {
+
+ override def hasNext: Boolean = {
+ val resultSize = cachedQueryResult().length
+ (state == ForState.VariableCleanup &&
dropVariablesExec.getTreeIterator.hasNext) ||
+ (!interrupted && resultSize > 0 && currRow < resultSize)
+ }
+
+ override def next(): CompoundStatementExec = state match {
+
+ case ForState.VariableAssignment =>
+ variablesMap =
createVariablesMapFromRow(cachedQueryResult()(currRow))
+
+ if (!areVariablesDeclared) {
+ // create and execute declare var statements
+ variablesMap.keys.toSeq
+ .map(colName => createDeclareVarExec(colName,
variablesMap(colName)))
+ .foreach(declareVarExec =>
declareVarExec.buildDataFrame(session).collect())
+ areVariablesDeclared = true
+ }
+
+ // create and execute set var statements
+ variablesMap.keys.toSeq
+ .map(colName => createSetVarExec(colName, variablesMap(colName)))
+ .foreach(setVarExec =>
setVarExec.buildDataFrame(session).collect())
+
+ state = ForState.Body
+ body.reset()
+ next()
+
+ case ForState.Body =>
+ val retStmt = body.getTreeIterator.next()
+
+ // Handle LEAVE or ITERATE statement if it has been encountered.
+ retStmt match {
+ case leaveStatementExec: LeaveStatementExec if
!leaveStatementExec.hasBeenMatched =>
+ if (label.contains(leaveStatementExec.label)) {
+ leaveStatementExec.hasBeenMatched = true
+ }
+ interrupted = true
+ // If this for statement encounters LEAVE, it will either not be executed
+ // again, or it will be reset before being executed.
+ // In either case, variables will not
+ // be dropped normally, from ForState.VariableCleanup, so we drop them here.
+ dropVars()
+ return retStmt
+ case iterStatementExec: IterateStatementExec if
!iterStatementExec.hasBeenMatched =>
+ if (label.contains(iterStatementExec.label)) {
+ iterStatementExec.hasBeenMatched = true
+ } else {
+ // if an outer loop is being iterated, this for statement will either not be
+ // executed again, or it will be reset before being executed.
+ // In either case, variables will not
+ // be dropped normally, from ForState.VariableCleanup, so we drop them here.
+ dropVars()
+ }
+ switchStateFromBody()
+ return retStmt
+ case _ =>
+ }
+
+ if (!body.getTreeIterator.hasNext) {
+ switchStateFromBody()
+ }
+ retStmt
+
+ case ForState.VariableCleanup =>
+ dropVariablesExec.getTreeIterator.next()
+ }
+ }
+
+ /**
+ * Recursively creates a Catalyst expression from Scala value.<br>
+ * See https://spark.apache.org/docs/latest/sql-ref-datatypes.html for Spark -> Scala mappings
+ */
+ private def createExpressionFromValue(value: Any): Expression = value match {
+ case m: Map[_, _] =>
+ // arguments of CreateMap are in the format: (key1, val1, key2, val2, ...)
+ val mapArgs = m.keys.toSeq.flatMap { key =>
+ Seq(createExpressionFromValue(key), createExpressionFromValue(m(key)))
+ }
+ CreateMap(mapArgs, useStringTypeWhenEmpty = false)
+
+ // structs and rows match this case
+ case s: Row =>
+ // arguments of CreateNamedStruct are in the format: (name1, val1, name2, val2, ...)
+ val namedStructArgs = s.schema.names.toSeq.flatMap { colName =>
+ val valueExpression = createExpressionFromValue(s.getAs(colName))
+ Seq(Literal(colName), valueExpression)
+ }
+ CreateNamedStruct(namedStructArgs)
+
+ // arrays match this case
+ case a: collection.Seq[_] =>
+ val arrayArgs = a.toSeq.map(createExpressionFromValue(_))
+ CreateArray(arrayArgs, useStringTypeWhenEmpty = false)
+ case _ => Literal(value)
Review Comment:
nit: add an empty line above this `case _ =>`. I started looking for where
this case had disappeared compared to the previous iterations 😂 A blank line
here will make it a bit cleaner.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]