gerashegalov commented on code in PR #41203: URL: https://github.com/apache/spark/pull/41203#discussion_r1197205234
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala: ########## @@ -74,3 +74,44 @@ object ExpectsInputTypes extends QueryErrorsBase { trait ImplicitCastInputTypes extends ExpectsInputTypes { // No other methods } + +/** + * An extension of the ExpectsInputTypes trait that also checks each + * input expression's foldable attribute against inputIsFoldable; + * inputInputFoldable is a Seq of type Option[Boolean], specifying + * a value of None bypasses the check for the given column + */ +trait ExpectsInputTypesAndFoldable extends ExpectsInputTypes { + def inputIsFoldable: Seq[Option[Boolean]] + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + ExpectsInputTypesAndFoldable.checkInputIsFoldable(children, inputIsFoldable) + case failure => failure + } + } +} + +object ExpectsInputTypesAndFoldable extends QueryErrorsBase { + + def checkInputIsFoldable( + inputs: Seq[Expression], + inputIsFoldable: Seq[Option[Boolean]]): TypeCheckResult = { + val mismatch = inputs.zip(inputIsFoldable).zipWithIndex.filter { + case ((_, Some(_)), _) => true + case _ => false + }.collectFirst { + case ((input, expected), idx) if input.foldable != expected.get => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_FOLDABLE_VALUE", + messageParameters = Map( + "paramIndex" -> (idx + 1).toString, + "inputSql" -> toSQLExpr(input), + "inputFoldable" -> input.foldable.toString, + "requiredFoldable" -> expected.get.toString)) + } Review Comment: nit: `collectFirst` can subsume the `filter` step: by pattern-matching on `Some(expected)` it skips the `None` entries itself and binds `expected` directly, which also removes both `expected.get` calls. 
```suggestion val mismatch = inputs.zip(inputIsFoldable).zipWithIndex.collectFirst { case ((input, Some(expected)), idx) if input.foldable != expected => DataTypeMismatch( errorSubClass = "UNEXPECTED_INPUT_FOLDABLE_VALUE", messageParameters = Map( "paramIndex" -> (idx + 1).toString, "inputSql" -> toSQLExpr(input), "inputFoldable" -> input.foldable.toString, "requiredFoldable" -> expected.toString)) } ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org