dtenedor commented on code in PR #49518:
URL: https://github.com/apache/spark/pull/49518#discussion_r1920578691
##########
common/utils/src/main/resources/error/error-conditions.json:
##########
@@ -3099,6 +3099,29 @@
     ],
     "sqlState" : "42602"
   },
+  "INVALID_RECURSIVE_REFERENCE" : {
+    "message" : [
+      "Invalid recursive reference found."
Review Comment:
This doesn't mention that it's about the WITH clause (or the similar DataFrame API). Can we mention these specifically here so the user knows which part of the query this is referring to?
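For example, the top-level message could name the construct explicitly. The wording below is only a sketch to illustrate the idea, not a required phrasing:

    "INVALID_RECURSIVE_REFERENCE" : {
      "message" : [
        "Invalid recursive reference found inside a recursive CTE (WITH RECURSIVE clause) or the corresponding DataFrame API."
      ],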
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala:
##########
@@ -1043,6 +1044,75 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
     if (Utils.isTesting) scrubOutIds(result) else result
   }
+  /**
+   * Recursion, according to SQL standard, comes with several limitations:
+   * 1. Recursive term can contain one recursive reference only.
+   * 2. Recursive reference can't be used in some kinds of joins and aggregations.
+   * This rule checks that these restrictions are not violated.
+   */
+  private def checkRecursion(
+      plan: LogicalPlan,
+      references: mutable.Map[Long, (Int, Seq[DataType])] = mutable.Map.empty): Unit = {
+    plan match {
+      // The map is filled with UnionLoop id as key and 0 (number of Ref occasions) and datatype
+      // as value
+      case UnionLoop(id, anchor, recursion, _) =>
+        checkRecursion(anchor, references)
+        checkRecursion(recursion, references += id -> (0, anchor.output.map(_.dataType)))
+        references -= id
+      case r @ UnionLoopRef(loopId, output, false) =>
+        // If we encounter a recursive reference, it has to be present in the map
+        if (!references.contains(loopId)) {
+          r.failAnalysis(
+            errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE",
+            messageParameters = Map.empty
+          )
+        }
+        val (count, dataType) = references(loopId)
+        if (count > 0) {
+          r.failAnalysis(
+            errorClass = "INVALID_RECURSIVE_REFERENCE.NUMBER",
+            messageParameters = Map.empty
+          )
+        }
+        val originalDataType = r.output.map(_.dataType)
+        if (!originalDataType.zip(dataType).forall {
+          case (odt, dt) => DataType.equalsStructurally(odt, dt, true)
Review Comment:
Can you add an implementation comment for this check? It seems non-trivial. Why are we using this type of check for the data types?
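For instance, the comment plus the check could look roughly like the sketch below. This is illustrative only: `anchorTypes`/`refTypes` are made-up names, and the stated rationale is my reading of the intent, so please correct it if it is wrong:

    import org.apache.spark.sql.types._

    // The anchor and the recursive reference may legitimately disagree on field
    // names and nullability while still producing compatible rows, so the
    // schemas are compared structurally with nullability ignored instead of
    // requiring exact equality.
    val anchorTypes: Seq[DataType] = Seq(IntegerType, StringType)
    val refTypes: Seq[DataType] = Seq(IntegerType, StringType)

    val compatible = refTypes.zip(anchorTypes).forall {
      case (odt, dt) => DataType.equalsStructurally(odt, dt, ignoreNullability = true)
    }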
##########
common/utils/src/main/resources/error/error-conditions.json:
##########
@@ -3099,6 +3099,29 @@
     ],
     "sqlState" : "42602"
   },
+  "INVALID_RECURSIVE_REFERENCE" : {
+    "message" : [
+      "Invalid recursive reference found."
+    ],
+    "subClass" : {
+      "DATA_TYPE" : {
+        "message" : [
+          "The data type of recursive references cannot change during resolution. Originally it was <fromDataType> but after resolution is <toDataType>."
Review Comment:
can you also mention what the user should do to modify the query to make it
succeed upon a subsequent attempt? Same below.
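For example, the sub-class message could end with an actionable hint along these lines (hypothetical wording, feel free to rephrase):

    "DATA_TYPE" : {
      "message" : [
        "The data type of recursive references cannot change during resolution. Originally it was <fromDataType> but after resolution is <toDataType>. Please cast the columns of the recursive term so that they match the data types of the anchor term."
      ]
    },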
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala:
##########
@@ -1043,6 +1044,75 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
     if (Utils.isTesting) scrubOutIds(result) else result
   }
+  /**
+   * Recursion, according to SQL standard, comes with several limitations:
+   * 1. Recursive term can contain one recursive reference only.
+   * 2. Recursive reference can't be used in some kinds of joins and aggregations.
+   * This rule checks that these restrictions are not violated.
+   */
+  private def checkRecursion(
+      plan: LogicalPlan,
+      references: mutable.Map[Long, (Int, Seq[DataType])] = mutable.Map.empty): Unit = {
+    plan match {
+      // The map is filled with UnionLoop id as key and 0 (number of Ref occasions) and datatype
+      // as value
+      case UnionLoop(id, anchor, recursion, _) =>
+        checkRecursion(anchor, references)
+        checkRecursion(recursion, references += id -> (0, anchor.output.map(_.dataType)))
+        references -= id
+      case r @ UnionLoopRef(loopId, output, false) =>
+        // If we encounter a recursive reference, it has to be present in the map
+        if (!references.contains(loopId)) {
+          r.failAnalysis(
+            errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE",
+            messageParameters = Map.empty
+          )
+        }
+        val (count, dataType) = references(loopId)
+        if (count > 0) {
+          r.failAnalysis(
+            errorClass = "INVALID_RECURSIVE_REFERENCE.NUMBER",
+            messageParameters = Map.empty
+          )
+        }
+        val originalDataType = r.output.map(_.dataType)
+        if (!originalDataType.zip(dataType).forall {
+          case (odt, dt) => DataType.equalsStructurally(odt, dt, true)
+        }) {
+          r.failAnalysis(
+            errorClass = "INVALID_RECURSIVE_REFERENCE.DATA_TYPE",
+            messageParameters = Map(
+              "fromDataType" -> originalDataType.map(toSQLType).mkString(", "),
+              "toDataType" -> dataType.map(toSQLType).mkString(", ")
+            )
+          )
+        }
+        references(loopId) = (count + 1, dataType)
+      case Join(left, right, Inner, _, _) =>
+        checkRecursion(left, references)
Review Comment:
This algorithm is going to create a lot of stack frames. Could you please convert it to a loop instead, starting with the initial operator to check and using a queue to add new nodes to check, popping them off after checking them? In this way, we can improve performance and memory usage.
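Something along the lines of the sketch below could work. This is illustrative only: `Node`/`children` stand in for the actual LogicalPlan operators, and the per-branch bookkeeping (adding/removing loop ids, the reference counts) would need to travel with the queued entries:

    import scala.collection.mutable

    // Instead of recursing into every child, start from the operator to check,
    // keep a queue of nodes that still need checking, and pop them off one by
    // one; children are enqueued rather than visited through new stack frames.
    case class Node(name: String, children: Seq[Node] = Seq.empty)

    def checkIteratively(root: Node): Unit = {
      val queue = mutable.Queue(root)
      while (queue.nonEmpty) {
        val node = queue.dequeue()
        // ... per-node checks (recursive reference counting, data type checks) ...
        queue ++= node.children
      }
    }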
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]