This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d999f622dc6 [SPARK-45117][SQL] Implement missing otherCopyArgs for the MultiCommutativeOp expression

d999f622dc6 is described below

commit d999f622dc68b4fb2734e2ac7cbe203b062c257f
Author: Supun Nakandala <supun.nakand...@databricks.com>
AuthorDate: Tue Sep 12 23:52:22 2023 +0800

[SPARK-45117][SQL] Implement missing otherCopyArgs for the MultiCommutativeOp expression

### What changes were proposed in this pull request?
- This PR implements the missing `otherCopyArgs` method in the `MultiCommutativeOp` expression.

### Why are the changes needed?
- Without this method, calling `toJSON` on an expression tree that contains a `MultiCommutativeOp` throws an exception from the `TreeNode.jsonFields` method.
- This is because `jsonFields` asserts that the number of fields declared in the constructor equals the number of field values (`productIterator.toSeq ++ otherCopyArgs`).
- The `originalRoot` field of `MultiCommutativeOp` is declared in a second constructor parameter list, so it is not part of `productIterator`. Hence, it has to be returned explicitly from the `otherCopyArgs` method.

### Does this PR introduce _any_ user-facing change?
- No

### How was this patch tested?
- Added a unit test.

### Was this patch authored or co-authored using generative AI tooling?
- No

Closes #42873 from db-scnakandala/multi-commutative-op.
Authored-by: Supun Nakandala <supun.nakand...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../apache/spark/sql/catalyst/expressions/Expression.scala | 2 ++ .../spark/sql/catalyst/expressions/CanonicalizeSuite.scala | 13 +++++++++++++ 2 files changed, 15 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c2330cdb59d..bd7369e57b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -1410,4 +1410,6 @@ case class MultiCommutativeOp( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = this.copy(operands = newChildren)(originalRoot) + + override protected final def otherCopyArgs: Seq[AnyRef] = originalRoot :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index 0e22b0d2876..89175ea1970 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -338,4 +338,17 @@ class CanonicalizeSuite extends SparkFunSuite { SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, default.toString) } + + test("toJSON works properly with MultiCommutativeOp") { + val default = SQLConf.get.getConf(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD) + SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, "1") + + val d = Decimal(1.2) + val literal1 = Literal.create(d, DecimalType(2, 1)) + val literal2 = Literal.create(d, DecimalType(2, 1)) + val literal3 = Literal.create(d, DecimalType(3, 2)) + val op = 
Add(literal1, Add(literal2, literal3)) + assert(op.canonicalized.toJSON.nonEmpty) + SQLConf.get.setConfString(MULTI_COMMUTATIVE_OP_OPT_THRESHOLD.key, default.toString) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org