MaxGekk commented on code in PR #49029: URL: https://github.com/apache/spark/pull/49029#discussion_r1880465089
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExplicitlyUnsupportedResolverFeature.scala: ########## @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +/** + * This is an addon to [[ResolverGuard]] functionality for features that cannot be determined by + * only looking at the unresolved plan. [[Resolver]] will throw this control-flow exception + * when it encounters some explicitly unsupported feature. It will later be caught by + * [[HybridAnalyzer]] to abort single-pass analysis without comparing single-pass and fixed-point + * results. The motivation for this feature is the same as for the [[ResolverGuard]] - we want to + * have an explicit allowlist of the unimplemented features that we are aware of, and + * `UNSUPPORTED_SINGLE_PASS_ANALYZER_FEATURE` will signal us the rest of the gaps. + * + * For example, [[UnresolvedRelation]] can be intermediately resolved by [[ResolveRelations]] as + * [[UnresolvedCatalogRelation]] or a [[View]] (among all others). 
Say that for now the views + * are not implemented, and we are aware of that, so [[ExplicitlyUnsupportedResolverFeature]] will + * be thrown in the middile of the single-pass analysis to abort it. Review Comment: ```suggestion * be thrown in the middle of the single-pass analysis to abort it. ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala: ########## @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + ArrayDistinct, + ArrayInsert, + ArrayJoin, + ArrayMax, + ArrayMin, + ArraysZip, + AttributeReference, + BinaryExpression, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + ExtractANSIIntervalDays, + GetArrayStructFields, + GetMapValue, + GetStructField, + Literal, + MapConcat, + MapContainsKey, + MapEntries, + MapFromEntries, + MapKeys, + MapValues, + NamedExpression, + Predicate, + RuntimeReplaceable, + StringRPad, + StringToMap, + TimeZoneAwareExpression, + UnaryMinus +} +import org.apache.spark.sql.types.BooleanType + +/** + * The [[ExpressionResolutionValidator]] performs the validation work on the expression tree for the + * [[ResolutionValidator]]. These two components work together recursively validating the + * logical plan. You can find more info in the [[ResolutionValidator]] scaladoc. + */ +class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { + + /** + * Validate resolved expression tree. The principle is the same as + * [[ResolutionValidator.validate]]. 
+ */ + def validate(expression: Expression): Unit = { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + case getMapValue: GetMapValue => + validateGetMapValue(getMapValue) + case binaryExpression: BinaryExpression => + validateBinaryExpression(binaryExpression) + case extractANSIIntervalDay: ExtractANSIIntervalDays => + validateExtractANSIIntervalDays(extractANSIIntervalDay) + case literal: Literal => + validateLiteral(literal) + case predicate: Predicate => + validatePredicate(predicate) + case stringRPad: StringRPad => + validateStringRPad(stringRPad) + case unaryMinus: UnaryMinus => + validateUnaryMinus(unaryMinus) + case getStructField: GetStructField => + validateGetStructField(getStructField) + case createNamedStruct: CreateNamedStruct => + validateCreateNamedStruct(createNamedStruct) + case getArrayStructFields: GetArrayStructFields => + validateGetArrayStructFields(getArrayStructFields) + case createMap: CreateMap => + validateCreateMap(createMap) + case stringToMap: StringToMap => + validateStringToMap(stringToMap) + case mapContainsKey: MapContainsKey => + validateMapContainsKey(mapContainsKey) + case mapConcat: MapConcat => + validateMapConcat(mapConcat) + case mapKeys: MapKeys => + validateMapKeys(mapKeys) + case mapValues: MapValues => + validateMapValues(mapValues) + case mapEntries: MapEntries => + validateMapEntries(mapEntries) + case mapFromEntries: MapFromEntries => + validateMapFromEntries(mapFromEntries) + case createArray: CreateArray => + validateCreateArray(createArray) + case arrayDistinct: ArrayDistinct => + validateArrayDistinct(arrayDistinct) + case arrayInsert: ArrayInsert => + validateArrayInsert(arrayInsert) + case arrayJoin: ArrayJoin => + validateArrayJoin(arrayJoin) + case arrayMax: ArrayMax => + validateArrayMax(arrayMax) + case arrayMin: ArrayMin => + validateArrayMin(arrayMin) + case arraysZip: ArraysZip => + 
validateArraysZip(arraysZip) + case runtimeReplaceable: RuntimeReplaceable => + validateRuntimeReplaceable(runtimeReplaceable) + case timezoneExpression: TimeZoneAwareExpression => + validateTimezoneExpression(timezoneExpression) + } + } + + def validateProjectList(projectList: Seq[NamedExpression]): Unit = { + projectList.foreach(expression => { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + } + }) + } + + private def validatePredicate(predicate: Predicate) = { + predicate.children.foreach(validate) + assert( + predicate.dataType == BooleanType, + s"Output type of a predicate must be a boolean, but got: ${predicate.dataType.typeName}" + ) + assert( + predicate.checkInputDataTypes().isSuccess, + s"Input types of a predicate must be valid, but got: " + + s"${predicate.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateStringRPad(stringRPad: StringRPad) = { + validate(stringRPad.first) + validate(stringRPad.second) + validate(stringRPad.third) + assert( + stringRPad.checkInputDataTypes().isSuccess, + s"Input types of rpad must be valid, but got: " + + s"${stringRPad.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateAttributeReference(attributeReference: AttributeReference): Unit = { + assert( + resolutionValidator.attributeScopeStack.top.contains(attributeReference), + s"Attribute $attributeReference is missing from attribute scope: " + + s"${resolutionValidator.attributeScopeStack.top}" + ) + } + + private def validateAlias(alias: Alias): Unit = { + validate(alias.child) + } + + private def validateBinaryExpression(binaryExpression: BinaryExpression): Unit = { + validate(binaryExpression.left) + validate(binaryExpression.right) + assert( + binaryExpression.checkInputDataTypes().isSuccess, + s"Input types of a binary expression must be valid, but got: " + + 
s"${binaryExpression.children.map(_.dataType.typeName).mkString(", ")}" + ) + + binaryExpression match { + case timezoneExpression: TimeZoneAwareExpression => + assert(timezoneExpression.timeZoneId.nonEmpty, "Timezone expression must have a timezone") + case _ => + } + } + + private def validateExtractANSIIntervalDays( + extractANSIIntervalDays: ExtractANSIIntervalDays): Unit = { + validate(extractANSIIntervalDays.child) + } + + private def validateLiteral(literal: Literal): Unit = {} + + private def validateUnaryMinus(unaryMinus: UnaryMinus): Unit = { + validate(unaryMinus.child) + assert( + unaryMinus.checkInputDataTypes().isSuccess, + s"Input types of a unary minus must be valid, but got: " + Review Comment: ```suggestion "Input types of a unary minus must be valid, but got: " + ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala: ########## @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + ArrayDistinct, + ArrayInsert, + ArrayJoin, + ArrayMax, + ArrayMin, + ArraysZip, + AttributeReference, + BinaryExpression, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + ExtractANSIIntervalDays, + GetArrayStructFields, + GetMapValue, + GetStructField, + Literal, + MapConcat, + MapContainsKey, + MapEntries, + MapFromEntries, + MapKeys, + MapValues, + NamedExpression, + Predicate, + RuntimeReplaceable, + StringRPad, + StringToMap, + TimeZoneAwareExpression, + UnaryMinus +} +import org.apache.spark.sql.types.BooleanType + +/** + * The [[ExpressionResolutionValidator]] performs the validation work on the expression tree for the + * [[ResolutionValidator]]. These two components work together recursively validating the + * logical plan. You can find more info in the [[ResolutionValidator]] scaladoc. + */ +class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { + + /** + * Validate resolved expression tree. The principle is the same as + * [[ResolutionValidator.validate]]. 
+ */ + def validate(expression: Expression): Unit = { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + case getMapValue: GetMapValue => + validateGetMapValue(getMapValue) + case binaryExpression: BinaryExpression => + validateBinaryExpression(binaryExpression) + case extractANSIIntervalDay: ExtractANSIIntervalDays => + validateExtractANSIIntervalDays(extractANSIIntervalDay) + case literal: Literal => + validateLiteral(literal) + case predicate: Predicate => + validatePredicate(predicate) + case stringRPad: StringRPad => + validateStringRPad(stringRPad) + case unaryMinus: UnaryMinus => + validateUnaryMinus(unaryMinus) + case getStructField: GetStructField => + validateGetStructField(getStructField) + case createNamedStruct: CreateNamedStruct => + validateCreateNamedStruct(createNamedStruct) + case getArrayStructFields: GetArrayStructFields => + validateGetArrayStructFields(getArrayStructFields) + case createMap: CreateMap => + validateCreateMap(createMap) + case stringToMap: StringToMap => + validateStringToMap(stringToMap) + case mapContainsKey: MapContainsKey => + validateMapContainsKey(mapContainsKey) + case mapConcat: MapConcat => + validateMapConcat(mapConcat) + case mapKeys: MapKeys => + validateMapKeys(mapKeys) + case mapValues: MapValues => + validateMapValues(mapValues) + case mapEntries: MapEntries => + validateMapEntries(mapEntries) + case mapFromEntries: MapFromEntries => + validateMapFromEntries(mapFromEntries) + case createArray: CreateArray => + validateCreateArray(createArray) + case arrayDistinct: ArrayDistinct => + validateArrayDistinct(arrayDistinct) + case arrayInsert: ArrayInsert => + validateArrayInsert(arrayInsert) + case arrayJoin: ArrayJoin => + validateArrayJoin(arrayJoin) + case arrayMax: ArrayMax => + validateArrayMax(arrayMax) + case arrayMin: ArrayMin => + validateArrayMin(arrayMin) + case arraysZip: ArraysZip => + 
validateArraysZip(arraysZip) + case runtimeReplaceable: RuntimeReplaceable => + validateRuntimeReplaceable(runtimeReplaceable) + case timezoneExpression: TimeZoneAwareExpression => + validateTimezoneExpression(timezoneExpression) + } + } + + def validateProjectList(projectList: Seq[NamedExpression]): Unit = { + projectList.foreach(expression => { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + } + }) + } + + private def validatePredicate(predicate: Predicate) = { + predicate.children.foreach(validate) + assert( + predicate.dataType == BooleanType, + s"Output type of a predicate must be a boolean, but got: ${predicate.dataType.typeName}" + ) + assert( + predicate.checkInputDataTypes().isSuccess, + s"Input types of a predicate must be valid, but got: " + Review Comment: interpolation is not needed here: ```suggestion "Input types of a predicate must be valid, but got: " + ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala: ########## @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + ArrayDistinct, + ArrayInsert, + ArrayJoin, + ArrayMax, + ArrayMin, + ArraysZip, + AttributeReference, + BinaryExpression, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + ExtractANSIIntervalDays, + GetArrayStructFields, + GetMapValue, + GetStructField, + Literal, + MapConcat, + MapContainsKey, + MapEntries, + MapFromEntries, + MapKeys, + MapValues, + NamedExpression, + Predicate, + RuntimeReplaceable, + StringRPad, + StringToMap, + TimeZoneAwareExpression, + UnaryMinus +} +import org.apache.spark.sql.types.BooleanType + +/** + * The [[ExpressionResolutionValidator]] performs the validation work on the expression tree for the + * [[ResolutionValidator]]. These two components work together recursively validating the + * logical plan. You can find more info in the [[ResolutionValidator]] scaladoc. + */ +class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { + + /** + * Validate resolved expression tree. The principle is the same as + * [[ResolutionValidator.validate]]. 
+ */ + def validate(expression: Expression): Unit = { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + case getMapValue: GetMapValue => + validateGetMapValue(getMapValue) + case binaryExpression: BinaryExpression => + validateBinaryExpression(binaryExpression) + case extractANSIIntervalDay: ExtractANSIIntervalDays => + validateExtractANSIIntervalDays(extractANSIIntervalDay) + case literal: Literal => + validateLiteral(literal) + case predicate: Predicate => + validatePredicate(predicate) + case stringRPad: StringRPad => + validateStringRPad(stringRPad) + case unaryMinus: UnaryMinus => + validateUnaryMinus(unaryMinus) + case getStructField: GetStructField => + validateGetStructField(getStructField) + case createNamedStruct: CreateNamedStruct => + validateCreateNamedStruct(createNamedStruct) + case getArrayStructFields: GetArrayStructFields => + validateGetArrayStructFields(getArrayStructFields) + case createMap: CreateMap => + validateCreateMap(createMap) + case stringToMap: StringToMap => + validateStringToMap(stringToMap) + case mapContainsKey: MapContainsKey => + validateMapContainsKey(mapContainsKey) + case mapConcat: MapConcat => + validateMapConcat(mapConcat) + case mapKeys: MapKeys => + validateMapKeys(mapKeys) + case mapValues: MapValues => + validateMapValues(mapValues) + case mapEntries: MapEntries => + validateMapEntries(mapEntries) + case mapFromEntries: MapFromEntries => + validateMapFromEntries(mapFromEntries) + case createArray: CreateArray => + validateCreateArray(createArray) + case arrayDistinct: ArrayDistinct => + validateArrayDistinct(arrayDistinct) + case arrayInsert: ArrayInsert => + validateArrayInsert(arrayInsert) + case arrayJoin: ArrayJoin => + validateArrayJoin(arrayJoin) + case arrayMax: ArrayMax => + validateArrayMax(arrayMax) + case arrayMin: ArrayMin => + validateArrayMin(arrayMin) + case arraysZip: ArraysZip => + 
validateArraysZip(arraysZip) + case runtimeReplaceable: RuntimeReplaceable => + validateRuntimeReplaceable(runtimeReplaceable) + case timezoneExpression: TimeZoneAwareExpression => + validateTimezoneExpression(timezoneExpression) + } + } + + def validateProjectList(projectList: Seq[NamedExpression]): Unit = { + projectList.foreach(expression => { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + } + }) + } + + private def validatePredicate(predicate: Predicate) = { + predicate.children.foreach(validate) + assert( + predicate.dataType == BooleanType, + s"Output type of a predicate must be a boolean, but got: ${predicate.dataType.typeName}" + ) + assert( + predicate.checkInputDataTypes().isSuccess, + s"Input types of a predicate must be valid, but got: " + + s"${predicate.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateStringRPad(stringRPad: StringRPad) = { + validate(stringRPad.first) + validate(stringRPad.second) + validate(stringRPad.third) + assert( + stringRPad.checkInputDataTypes().isSuccess, + s"Input types of rpad must be valid, but got: " + + s"${stringRPad.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateAttributeReference(attributeReference: AttributeReference): Unit = { + assert( + resolutionValidator.attributeScopeStack.top.contains(attributeReference), + s"Attribute $attributeReference is missing from attribute scope: " + + s"${resolutionValidator.attributeScopeStack.top}" + ) + } + + private def validateAlias(alias: Alias): Unit = { + validate(alias.child) + } + + private def validateBinaryExpression(binaryExpression: BinaryExpression): Unit = { + validate(binaryExpression.left) + validate(binaryExpression.right) + assert( + binaryExpression.checkInputDataTypes().isSuccess, + s"Input types of a binary expression must be valid, but got: " + + 
s"${binaryExpression.children.map(_.dataType.typeName).mkString(", ")}" + ) + + binaryExpression match { + case timezoneExpression: TimeZoneAwareExpression => + assert(timezoneExpression.timeZoneId.nonEmpty, "Timezone expression must have a timezone") + case _ => + } + } + + private def validateExtractANSIIntervalDays( + extractANSIIntervalDays: ExtractANSIIntervalDays): Unit = { + validate(extractANSIIntervalDays.child) + } + + private def validateLiteral(literal: Literal): Unit = {} + + private def validateUnaryMinus(unaryMinus: UnaryMinus): Unit = { + validate(unaryMinus.child) + assert( + unaryMinus.checkInputDataTypes().isSuccess, + s"Input types of a unary minus must be valid, but got: " + + s"${unaryMinus.child.dataType.typeName.mkString(", ")}" Review Comment: ```suggestion unaryMinus.child.dataType.typeName.mkString(", ") ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DelegatesResolutionToExtensions.scala: ########## @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.errors.QueryCompilationErrors + +/** + * The [[DelegatesResolutionToExtensions]] is a trait which provides a method to delegate the + * resolution of unresolved operators to a list of [[ResolverExtension]]s. + */ +trait DelegatesResolutionToExtensions { + + protected val extensions: Seq[ResolverExtension] + + /** + * Find the suitable extension for `unresolvedOperator` resolution and resolve it with that + * extension. Usually extensions return resolved relation nodes, so we generically update the name + * scope without matching for specific relations, for simplicity. + * + * We match the extension once to reduce the number of + * [[ResolverExtension.resolveOperator.isDefinedAt]] calls, because those can be expensive. + * + * Returns `Some(resolutionResult)` if the extension was found and `unresolvedOperator` was + * resolved, `None` otherwise. + * + * Throws `AMBIGUOUS_RESOLVER_EXTENSION` if there were several matched extensions for this Review Comment: you could use tags here: ```suggestion * @return `Some(resolutionResult)` if the extension was found and `unresolvedOperator` was * resolved, `None` otherwise. * * @throws `AMBIGUOUS_RESOLVER_EXTENSION` if there were several matched extensions for this ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala: ########## @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + ArrayDistinct, + ArrayInsert, + ArrayJoin, + ArrayMax, + ArrayMin, + ArraysZip, + AttributeReference, + BinaryExpression, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + ExtractANSIIntervalDays, + GetArrayStructFields, + GetMapValue, + GetStructField, + Literal, + MapConcat, + MapContainsKey, + MapEntries, + MapFromEntries, + MapKeys, + MapValues, + NamedExpression, + Predicate, + RuntimeReplaceable, + StringRPad, + StringToMap, + TimeZoneAwareExpression, + UnaryMinus +} +import org.apache.spark.sql.types.BooleanType + +/** + * The [[ExpressionResolutionValidator]] performs the validation work on the expression tree for the + * [[ResolutionValidator]]. These two components work together recursively validating the + * logical plan. You can find more info in the [[ResolutionValidator]] scaladoc. + */ +class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { + + /** + * Validate resolved expression tree. The principle is the same as + * [[ResolutionValidator.validate]]. 
+ */ + def validate(expression: Expression): Unit = { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + case getMapValue: GetMapValue => + validateGetMapValue(getMapValue) + case binaryExpression: BinaryExpression => + validateBinaryExpression(binaryExpression) + case extractANSIIntervalDay: ExtractANSIIntervalDays => + validateExtractANSIIntervalDays(extractANSIIntervalDay) + case literal: Literal => + validateLiteral(literal) + case predicate: Predicate => + validatePredicate(predicate) + case stringRPad: StringRPad => + validateStringRPad(stringRPad) + case unaryMinus: UnaryMinus => + validateUnaryMinus(unaryMinus) + case getStructField: GetStructField => + validateGetStructField(getStructField) + case createNamedStruct: CreateNamedStruct => + validateCreateNamedStruct(createNamedStruct) + case getArrayStructFields: GetArrayStructFields => + validateGetArrayStructFields(getArrayStructFields) + case createMap: CreateMap => + validateCreateMap(createMap) + case stringToMap: StringToMap => + validateStringToMap(stringToMap) + case mapContainsKey: MapContainsKey => + validateMapContainsKey(mapContainsKey) + case mapConcat: MapConcat => + validateMapConcat(mapConcat) + case mapKeys: MapKeys => + validateMapKeys(mapKeys) + case mapValues: MapValues => + validateMapValues(mapValues) + case mapEntries: MapEntries => + validateMapEntries(mapEntries) + case mapFromEntries: MapFromEntries => + validateMapFromEntries(mapFromEntries) + case createArray: CreateArray => + validateCreateArray(createArray) + case arrayDistinct: ArrayDistinct => + validateArrayDistinct(arrayDistinct) + case arrayInsert: ArrayInsert => + validateArrayInsert(arrayInsert) + case arrayJoin: ArrayJoin => + validateArrayJoin(arrayJoin) + case arrayMax: ArrayMax => + validateArrayMax(arrayMax) + case arrayMin: ArrayMin => + validateArrayMin(arrayMin) + case arraysZip: ArraysZip => + 
validateArraysZip(arraysZip) + case runtimeReplaceable: RuntimeReplaceable => + validateRuntimeReplaceable(runtimeReplaceable) + case timezoneExpression: TimeZoneAwareExpression => + validateTimezoneExpression(timezoneExpression) + } + } + + def validateProjectList(projectList: Seq[NamedExpression]): Unit = { + projectList.foreach(expression => { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + } + }) + } + + private def validatePredicate(predicate: Predicate) = { + predicate.children.foreach(validate) + assert( + predicate.dataType == BooleanType, + s"Output type of a predicate must be a boolean, but got: ${predicate.dataType.typeName}" + ) + assert( + predicate.checkInputDataTypes().isSuccess, + s"Input types of a predicate must be valid, but got: " + + s"${predicate.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateStringRPad(stringRPad: StringRPad) = { + validate(stringRPad.first) + validate(stringRPad.second) + validate(stringRPad.third) + assert( + stringRPad.checkInputDataTypes().isSuccess, + s"Input types of rpad must be valid, but got: " + Review Comment: ```suggestion "Input types of rpad must be valid, but got: " + ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala: ########## @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + ArrayDistinct, + ArrayInsert, + ArrayJoin, + ArrayMax, + ArrayMin, + ArraysZip, + AttributeReference, + BinaryExpression, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + ExtractANSIIntervalDays, + GetArrayStructFields, + GetMapValue, + GetStructField, + Literal, + MapConcat, + MapContainsKey, + MapEntries, + MapFromEntries, + MapKeys, + MapValues, + NamedExpression, + Predicate, + RuntimeReplaceable, + StringRPad, + StringToMap, + TimeZoneAwareExpression, + UnaryMinus +} +import org.apache.spark.sql.types.BooleanType + +/** + * The [[ExpressionResolutionValidator]] performs the validation work on the expression tree for the + * [[ResolutionValidator]]. These two components work together recursively validating the + * logical plan. You can find more info in the [[ResolutionValidator]] scaladoc. + */ +class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { + + /** + * Validate resolved expression tree. The principle is the same as + * [[ResolutionValidator.validate]]. 
+ */ + def validate(expression: Expression): Unit = { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + case getMapValue: GetMapValue => + validateGetMapValue(getMapValue) + case binaryExpression: BinaryExpression => + validateBinaryExpression(binaryExpression) + case extractANSIIntervalDay: ExtractANSIIntervalDays => + validateExtractANSIIntervalDays(extractANSIIntervalDay) + case literal: Literal => + validateLiteral(literal) + case predicate: Predicate => + validatePredicate(predicate) + case stringRPad: StringRPad => + validateStringRPad(stringRPad) + case unaryMinus: UnaryMinus => + validateUnaryMinus(unaryMinus) + case getStructField: GetStructField => + validateGetStructField(getStructField) + case createNamedStruct: CreateNamedStruct => + validateCreateNamedStruct(createNamedStruct) + case getArrayStructFields: GetArrayStructFields => + validateGetArrayStructFields(getArrayStructFields) + case createMap: CreateMap => + validateCreateMap(createMap) + case stringToMap: StringToMap => + validateStringToMap(stringToMap) + case mapContainsKey: MapContainsKey => + validateMapContainsKey(mapContainsKey) + case mapConcat: MapConcat => + validateMapConcat(mapConcat) + case mapKeys: MapKeys => + validateMapKeys(mapKeys) + case mapValues: MapValues => + validateMapValues(mapValues) + case mapEntries: MapEntries => + validateMapEntries(mapEntries) + case mapFromEntries: MapFromEntries => + validateMapFromEntries(mapFromEntries) + case createArray: CreateArray => + validateCreateArray(createArray) + case arrayDistinct: ArrayDistinct => + validateArrayDistinct(arrayDistinct) + case arrayInsert: ArrayInsert => + validateArrayInsert(arrayInsert) + case arrayJoin: ArrayJoin => + validateArrayJoin(arrayJoin) + case arrayMax: ArrayMax => + validateArrayMax(arrayMax) + case arrayMin: ArrayMin => + validateArrayMin(arrayMin) + case arraysZip: ArraysZip => + 
validateArraysZip(arraysZip) + case runtimeReplaceable: RuntimeReplaceable => + validateRuntimeReplaceable(runtimeReplaceable) + case timezoneExpression: TimeZoneAwareExpression => + validateTimezoneExpression(timezoneExpression) + } + } + + def validateProjectList(projectList: Seq[NamedExpression]): Unit = { + projectList.foreach(expression => { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + } + }) + } + + private def validatePredicate(predicate: Predicate) = { + predicate.children.foreach(validate) + assert( + predicate.dataType == BooleanType, + s"Output type of a predicate must be a boolean, but got: ${predicate.dataType.typeName}" + ) + assert( + predicate.checkInputDataTypes().isSuccess, + s"Input types of a predicate must be valid, but got: " + + s"${predicate.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateStringRPad(stringRPad: StringRPad) = { + validate(stringRPad.first) + validate(stringRPad.second) + validate(stringRPad.third) + assert( + stringRPad.checkInputDataTypes().isSuccess, + s"Input types of rpad must be valid, but got: " + + s"${stringRPad.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateAttributeReference(attributeReference: AttributeReference): Unit = { + assert( + resolutionValidator.attributeScopeStack.top.contains(attributeReference), + s"Attribute $attributeReference is missing from attribute scope: " + + s"${resolutionValidator.attributeScopeStack.top}" + ) + } + + private def validateAlias(alias: Alias): Unit = { + validate(alias.child) + } + + private def validateBinaryExpression(binaryExpression: BinaryExpression): Unit = { + validate(binaryExpression.left) + validate(binaryExpression.right) + assert( + binaryExpression.checkInputDataTypes().isSuccess, + s"Input types of a binary expression must be valid, but got: " + + 
s"${binaryExpression.children.map(_.dataType.typeName).mkString(", ")}" + ) + + binaryExpression match { + case timezoneExpression: TimeZoneAwareExpression => + assert(timezoneExpression.timeZoneId.nonEmpty, "Timezone expression must have a timezone") + case _ => + } + } + + private def validateExtractANSIIntervalDays( + extractANSIIntervalDays: ExtractANSIIntervalDays): Unit = { + validate(extractANSIIntervalDays.child) + } + + private def validateLiteral(literal: Literal): Unit = {} Review Comment: What's the purpose of invoking this empty method? ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala: ########## @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + ArrayDistinct, + ArrayInsert, + ArrayJoin, + ArrayMax, + ArrayMin, + ArraysZip, + AttributeReference, + BinaryExpression, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + ExtractANSIIntervalDays, + GetArrayStructFields, + GetMapValue, + GetStructField, + Literal, + MapConcat, + MapContainsKey, + MapEntries, + MapFromEntries, + MapKeys, + MapValues, + NamedExpression, + Predicate, + RuntimeReplaceable, + StringRPad, + StringToMap, + TimeZoneAwareExpression, + UnaryMinus +} +import org.apache.spark.sql.types.BooleanType + +/** + * The [[ExpressionResolutionValidator]] performs the validation work on the expression tree for the + * [[ResolutionValidator]]. These two components work together recursively validating the + * logical plan. You can find more info in the [[ResolutionValidator]] scaladoc. + */ +class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { + + /** + * Validate resolved expression tree. The principle is the same as + * [[ResolutionValidator.validate]]. 
+ */ + def validate(expression: Expression): Unit = { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + case getMapValue: GetMapValue => + validateGetMapValue(getMapValue) + case binaryExpression: BinaryExpression => + validateBinaryExpression(binaryExpression) + case extractANSIIntervalDay: ExtractANSIIntervalDays => + validateExtractANSIIntervalDays(extractANSIIntervalDay) + case literal: Literal => + validateLiteral(literal) + case predicate: Predicate => + validatePredicate(predicate) + case stringRPad: StringRPad => + validateStringRPad(stringRPad) + case unaryMinus: UnaryMinus => + validateUnaryMinus(unaryMinus) + case getStructField: GetStructField => + validateGetStructField(getStructField) + case createNamedStruct: CreateNamedStruct => + validateCreateNamedStruct(createNamedStruct) + case getArrayStructFields: GetArrayStructFields => + validateGetArrayStructFields(getArrayStructFields) + case createMap: CreateMap => + validateCreateMap(createMap) + case stringToMap: StringToMap => + validateStringToMap(stringToMap) + case mapContainsKey: MapContainsKey => + validateMapContainsKey(mapContainsKey) + case mapConcat: MapConcat => + validateMapConcat(mapConcat) + case mapKeys: MapKeys => + validateMapKeys(mapKeys) + case mapValues: MapValues => + validateMapValues(mapValues) + case mapEntries: MapEntries => + validateMapEntries(mapEntries) + case mapFromEntries: MapFromEntries => + validateMapFromEntries(mapFromEntries) + case createArray: CreateArray => + validateCreateArray(createArray) + case arrayDistinct: ArrayDistinct => + validateArrayDistinct(arrayDistinct) + case arrayInsert: ArrayInsert => + validateArrayInsert(arrayInsert) + case arrayJoin: ArrayJoin => + validateArrayJoin(arrayJoin) + case arrayMax: ArrayMax => + validateArrayMax(arrayMax) + case arrayMin: ArrayMin => + validateArrayMin(arrayMin) + case arraysZip: ArraysZip => + 
validateArraysZip(arraysZip) + case runtimeReplaceable: RuntimeReplaceable => + validateRuntimeReplaceable(runtimeReplaceable) + case timezoneExpression: TimeZoneAwareExpression => + validateTimezoneExpression(timezoneExpression) + } + } + + def validateProjectList(projectList: Seq[NamedExpression]): Unit = { + projectList.foreach(expression => { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + } + }) Review Comment: ```suggestion projectList.foreach { case attributeReference: AttributeReference => validateAttributeReference(attributeReference) case alias: Alias => validateAlias(alias) } ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala: ########## @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + ArrayDistinct, + ArrayInsert, + ArrayJoin, + ArrayMax, + ArrayMin, + ArraysZip, + AttributeReference, + BinaryExpression, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + ExtractANSIIntervalDays, + GetArrayStructFields, + GetMapValue, + GetStructField, + Literal, + MapConcat, + MapContainsKey, + MapEntries, + MapFromEntries, + MapKeys, + MapValues, + NamedExpression, + Predicate, + RuntimeReplaceable, + StringRPad, + StringToMap, + TimeZoneAwareExpression, + UnaryMinus +} +import org.apache.spark.sql.types.BooleanType + +/** + * The [[ExpressionResolutionValidator]] performs the validation work on the expression tree for the + * [[ResolutionValidator]]. These two components work together recursively validating the + * logical plan. You can find more info in the [[ResolutionValidator]] scaladoc. + */ +class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { + + /** + * Validate resolved expression tree. The principle is the same as + * [[ResolutionValidator.validate]]. 
+ */ + def validate(expression: Expression): Unit = { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + case getMapValue: GetMapValue => + validateGetMapValue(getMapValue) + case binaryExpression: BinaryExpression => + validateBinaryExpression(binaryExpression) + case extractANSIIntervalDay: ExtractANSIIntervalDays => + validateExtractANSIIntervalDays(extractANSIIntervalDay) + case literal: Literal => + validateLiteral(literal) + case predicate: Predicate => + validatePredicate(predicate) + case stringRPad: StringRPad => + validateStringRPad(stringRPad) + case unaryMinus: UnaryMinus => + validateUnaryMinus(unaryMinus) + case getStructField: GetStructField => + validateGetStructField(getStructField) + case createNamedStruct: CreateNamedStruct => + validateCreateNamedStruct(createNamedStruct) + case getArrayStructFields: GetArrayStructFields => + validateGetArrayStructFields(getArrayStructFields) + case createMap: CreateMap => + validateCreateMap(createMap) + case stringToMap: StringToMap => + validateStringToMap(stringToMap) + case mapContainsKey: MapContainsKey => + validateMapContainsKey(mapContainsKey) + case mapConcat: MapConcat => + validateMapConcat(mapConcat) + case mapKeys: MapKeys => + validateMapKeys(mapKeys) + case mapValues: MapValues => + validateMapValues(mapValues) + case mapEntries: MapEntries => + validateMapEntries(mapEntries) + case mapFromEntries: MapFromEntries => + validateMapFromEntries(mapFromEntries) + case createArray: CreateArray => + validateCreateArray(createArray) + case arrayDistinct: ArrayDistinct => + validateArrayDistinct(arrayDistinct) + case arrayInsert: ArrayInsert => + validateArrayInsert(arrayInsert) + case arrayJoin: ArrayJoin => + validateArrayJoin(arrayJoin) + case arrayMax: ArrayMax => + validateArrayMax(arrayMax) + case arrayMin: ArrayMin => + validateArrayMin(arrayMin) + case arraysZip: ArraysZip => + 
validateArraysZip(arraysZip) + case runtimeReplaceable: RuntimeReplaceable => + validateRuntimeReplaceable(runtimeReplaceable) + case timezoneExpression: TimeZoneAwareExpression => + validateTimezoneExpression(timezoneExpression) + } + } + + def validateProjectList(projectList: Seq[NamedExpression]): Unit = { + projectList.foreach(expression => { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + } + }) + } + + private def validatePredicate(predicate: Predicate) = { + predicate.children.foreach(validate) + assert( + predicate.dataType == BooleanType, + s"Output type of a predicate must be a boolean, but got: ${predicate.dataType.typeName}" + ) + assert( + predicate.checkInputDataTypes().isSuccess, + s"Input types of a predicate must be valid, but got: " + + s"${predicate.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateStringRPad(stringRPad: StringRPad) = { + validate(stringRPad.first) + validate(stringRPad.second) + validate(stringRPad.third) + assert( + stringRPad.checkInputDataTypes().isSuccess, + s"Input types of rpad must be valid, but got: " + + s"${stringRPad.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateAttributeReference(attributeReference: AttributeReference): Unit = { + assert( + resolutionValidator.attributeScopeStack.top.contains(attributeReference), + s"Attribute $attributeReference is missing from attribute scope: " + + s"${resolutionValidator.attributeScopeStack.top}" + ) + } + + private def validateAlias(alias: Alias): Unit = { + validate(alias.child) + } + + private def validateBinaryExpression(binaryExpression: BinaryExpression): Unit = { + validate(binaryExpression.left) + validate(binaryExpression.right) + assert( + binaryExpression.checkInputDataTypes().isSuccess, + s"Input types of a binary expression must be valid, but got: " + Review Comment: ```suggestion 
"Input types of a binary expression must be valid, but got: " + ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala: ########## @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + ArrayDistinct, + ArrayInsert, + ArrayJoin, + ArrayMax, + ArrayMin, + ArraysZip, + AttributeReference, + BinaryExpression, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + ExtractANSIIntervalDays, + GetArrayStructFields, + GetMapValue, + GetStructField, + Literal, + MapConcat, + MapContainsKey, + MapEntries, + MapFromEntries, + MapKeys, + MapValues, + NamedExpression, + Predicate, + RuntimeReplaceable, + StringRPad, + StringToMap, + TimeZoneAwareExpression, + UnaryMinus +} +import org.apache.spark.sql.types.BooleanType + +/** + * The [[ExpressionResolutionValidator]] performs the validation work on the expression tree for the + * [[ResolutionValidator]]. These two components work together recursively validating the + * logical plan. You can find more info in the [[ResolutionValidator]] scaladoc. 
+ */ +class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { + + /** + * Validate resolved expression tree. The principle is the same as + * [[ResolutionValidator.validate]]. + */ + def validate(expression: Expression): Unit = { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + case getMapValue: GetMapValue => + validateGetMapValue(getMapValue) + case binaryExpression: BinaryExpression => + validateBinaryExpression(binaryExpression) + case extractANSIIntervalDay: ExtractANSIIntervalDays => + validateExtractANSIIntervalDays(extractANSIIntervalDay) + case literal: Literal => + validateLiteral(literal) + case predicate: Predicate => + validatePredicate(predicate) + case stringRPad: StringRPad => + validateStringRPad(stringRPad) + case unaryMinus: UnaryMinus => + validateUnaryMinus(unaryMinus) + case getStructField: GetStructField => + validateGetStructField(getStructField) + case createNamedStruct: CreateNamedStruct => + validateCreateNamedStruct(createNamedStruct) + case getArrayStructFields: GetArrayStructFields => + validateGetArrayStructFields(getArrayStructFields) + case createMap: CreateMap => + validateCreateMap(createMap) + case stringToMap: StringToMap => + validateStringToMap(stringToMap) + case mapContainsKey: MapContainsKey => + validateMapContainsKey(mapContainsKey) + case mapConcat: MapConcat => + validateMapConcat(mapConcat) + case mapKeys: MapKeys => + validateMapKeys(mapKeys) + case mapValues: MapValues => + validateMapValues(mapValues) + case mapEntries: MapEntries => + validateMapEntries(mapEntries) + case mapFromEntries: MapFromEntries => + validateMapFromEntries(mapFromEntries) + case createArray: CreateArray => + validateCreateArray(createArray) + case arrayDistinct: ArrayDistinct => + validateArrayDistinct(arrayDistinct) + case arrayInsert: ArrayInsert => + validateArrayInsert(arrayInsert) + case 
arrayJoin: ArrayJoin => + validateArrayJoin(arrayJoin) + case arrayMax: ArrayMax => + validateArrayMax(arrayMax) + case arrayMin: ArrayMin => + validateArrayMin(arrayMin) + case arraysZip: ArraysZip => + validateArraysZip(arraysZip) + case runtimeReplaceable: RuntimeReplaceable => + validateRuntimeReplaceable(runtimeReplaceable) + case timezoneExpression: TimeZoneAwareExpression => + validateTimezoneExpression(timezoneExpression) + } + } + + def validateProjectList(projectList: Seq[NamedExpression]): Unit = { + projectList.foreach(expression => { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + } + }) + } + + private def validatePredicate(predicate: Predicate) = { + predicate.children.foreach(validate) + assert( + predicate.dataType == BooleanType, + s"Output type of a predicate must be a boolean, but got: ${predicate.dataType.typeName}" + ) + assert( + predicate.checkInputDataTypes().isSuccess, + s"Input types of a predicate must be valid, but got: " + + s"${predicate.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateStringRPad(stringRPad: StringRPad) = { + validate(stringRPad.first) + validate(stringRPad.second) + validate(stringRPad.third) + assert( + stringRPad.checkInputDataTypes().isSuccess, + s"Input types of rpad must be valid, but got: " + + s"${stringRPad.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateAttributeReference(attributeReference: AttributeReference): Unit = { + assert( + resolutionValidator.attributeScopeStack.top.contains(attributeReference), + s"Attribute $attributeReference is missing from attribute scope: " + + s"${resolutionValidator.attributeScopeStack.top}" + ) + } + + private def validateAlias(alias: Alias): Unit = { + validate(alias.child) + } + + private def validateBinaryExpression(binaryExpression: BinaryExpression): Unit = { + 
validate(binaryExpression.left) + validate(binaryExpression.right) + assert( + binaryExpression.checkInputDataTypes().isSuccess, + s"Input types of a binary expression must be valid, but got: " + + s"${binaryExpression.children.map(_.dataType.typeName).mkString(", ")}" + ) + + binaryExpression match { + case timezoneExpression: TimeZoneAwareExpression => + assert(timezoneExpression.timeZoneId.nonEmpty, "Timezone expression must have a timezone") + case _ => + } + } + + private def validateExtractANSIIntervalDays( + extractANSIIntervalDays: ExtractANSIIntervalDays): Unit = { + validate(extractANSIIntervalDays.child) + } + + private def validateLiteral(literal: Literal): Unit = {} + + private def validateUnaryMinus(unaryMinus: UnaryMinus): Unit = { + validate(unaryMinus.child) + assert( + unaryMinus.checkInputDataTypes().isSuccess, + s"Input types of a unary minus must be valid, but got: " + + s"${unaryMinus.child.dataType.typeName.mkString(", ")}" + ) + } + + private def validateGetStructField(getStructField: GetStructField): Unit = { + validate(getStructField.child) + } + + private def validateCreateNamedStruct(createNamedStruct: CreateNamedStruct): Unit = { + createNamedStruct.children.foreach(validate) + assert( + createNamedStruct.checkInputDataTypes().isSuccess, + s"Input types of CreateNamedStruct must be valid, but got: " + + s"${createNamedStruct.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateGetArrayStructFields(getArrayStructFields: GetArrayStructFields): Unit = { + validate(getArrayStructFields.child) + } + + private def validateGetMapValue(getMapValue: GetMapValue): Unit = { + validate(getMapValue.child) + validate(getMapValue.key) + assert( + getMapValue.checkInputDataTypes().isSuccess, + s"Input types of GetMapValue must be valid, but got: " + + s"${getMapValue.children.map(_.dataType.typeName).mkString(", ")}" Review Comment: ```suggestion "Input types of GetMapValue must be valid, but got: " + 
getMapValue.children.map(_.dataType.typeName).mkString(", ") ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala: ########## @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + ArrayDistinct, + ArrayInsert, + ArrayJoin, + ArrayMax, + ArrayMin, + ArraysZip, + AttributeReference, + BinaryExpression, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + ExtractANSIIntervalDays, + GetArrayStructFields, + GetMapValue, + GetStructField, + Literal, + MapConcat, + MapContainsKey, + MapEntries, + MapFromEntries, + MapKeys, + MapValues, + NamedExpression, + Predicate, + RuntimeReplaceable, + StringRPad, + StringToMap, + TimeZoneAwareExpression, + UnaryMinus +} +import org.apache.spark.sql.types.BooleanType + +/** + * The [[ExpressionResolutionValidator]] performs the validation work on the expression tree for the + * [[ResolutionValidator]]. These two components work together recursively validating the + * logical plan. You can find more info in the [[ResolutionValidator]] scaladoc. 
+ */ +class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { + + /** + * Validate resolved expression tree. The principle is the same as + * [[ResolutionValidator.validate]]. + */ + def validate(expression: Expression): Unit = { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + case getMapValue: GetMapValue => + validateGetMapValue(getMapValue) + case binaryExpression: BinaryExpression => + validateBinaryExpression(binaryExpression) + case extractANSIIntervalDay: ExtractANSIIntervalDays => + validateExtractANSIIntervalDays(extractANSIIntervalDay) + case literal: Literal => + validateLiteral(literal) + case predicate: Predicate => + validatePredicate(predicate) + case stringRPad: StringRPad => + validateStringRPad(stringRPad) + case unaryMinus: UnaryMinus => + validateUnaryMinus(unaryMinus) + case getStructField: GetStructField => + validateGetStructField(getStructField) + case createNamedStruct: CreateNamedStruct => + validateCreateNamedStruct(createNamedStruct) + case getArrayStructFields: GetArrayStructFields => + validateGetArrayStructFields(getArrayStructFields) + case createMap: CreateMap => + validateCreateMap(createMap) + case stringToMap: StringToMap => + validateStringToMap(stringToMap) + case mapContainsKey: MapContainsKey => + validateMapContainsKey(mapContainsKey) + case mapConcat: MapConcat => + validateMapConcat(mapConcat) + case mapKeys: MapKeys => + validateMapKeys(mapKeys) + case mapValues: MapValues => + validateMapValues(mapValues) + case mapEntries: MapEntries => + validateMapEntries(mapEntries) + case mapFromEntries: MapFromEntries => + validateMapFromEntries(mapFromEntries) + case createArray: CreateArray => + validateCreateArray(createArray) + case arrayDistinct: ArrayDistinct => + validateArrayDistinct(arrayDistinct) + case arrayInsert: ArrayInsert => + validateArrayInsert(arrayInsert) + case 
arrayJoin: ArrayJoin => + validateArrayJoin(arrayJoin) + case arrayMax: ArrayMax => + validateArrayMax(arrayMax) + case arrayMin: ArrayMin => + validateArrayMin(arrayMin) + case arraysZip: ArraysZip => + validateArraysZip(arraysZip) + case runtimeReplaceable: RuntimeReplaceable => + validateRuntimeReplaceable(runtimeReplaceable) + case timezoneExpression: TimeZoneAwareExpression => + validateTimezoneExpression(timezoneExpression) + } + } + + def validateProjectList(projectList: Seq[NamedExpression]): Unit = { + projectList.foreach(expression => { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + } + }) + } + + private def validatePredicate(predicate: Predicate) = { + predicate.children.foreach(validate) + assert( + predicate.dataType == BooleanType, + s"Output type of a predicate must be a boolean, but got: ${predicate.dataType.typeName}" + ) + assert( + predicate.checkInputDataTypes().isSuccess, + s"Input types of a predicate must be valid, but got: " + + s"${predicate.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateStringRPad(stringRPad: StringRPad) = { + validate(stringRPad.first) + validate(stringRPad.second) + validate(stringRPad.third) + assert( + stringRPad.checkInputDataTypes().isSuccess, + s"Input types of rpad must be valid, but got: " + + s"${stringRPad.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateAttributeReference(attributeReference: AttributeReference): Unit = { + assert( + resolutionValidator.attributeScopeStack.top.contains(attributeReference), + s"Attribute $attributeReference is missing from attribute scope: " + + s"${resolutionValidator.attributeScopeStack.top}" + ) + } + + private def validateAlias(alias: Alias): Unit = { + validate(alias.child) + } + + private def validateBinaryExpression(binaryExpression: BinaryExpression): Unit = { + 
validate(binaryExpression.left) + validate(binaryExpression.right) + assert( + binaryExpression.checkInputDataTypes().isSuccess, + s"Input types of a binary expression must be valid, but got: " + + s"${binaryExpression.children.map(_.dataType.typeName).mkString(", ")}" + ) + + binaryExpression match { + case timezoneExpression: TimeZoneAwareExpression => + assert(timezoneExpression.timeZoneId.nonEmpty, "Timezone expression must have a timezone") + case _ => + } + } + + private def validateExtractANSIIntervalDays( + extractANSIIntervalDays: ExtractANSIIntervalDays): Unit = { + validate(extractANSIIntervalDays.child) + } + + private def validateLiteral(literal: Literal): Unit = {} + + private def validateUnaryMinus(unaryMinus: UnaryMinus): Unit = { + validate(unaryMinus.child) + assert( + unaryMinus.checkInputDataTypes().isSuccess, + s"Input types of a unary minus must be valid, but got: " + + s"${unaryMinus.child.dataType.typeName.mkString(", ")}" + ) + } + + private def validateGetStructField(getStructField: GetStructField): Unit = { + validate(getStructField.child) + } + + private def validateCreateNamedStruct(createNamedStruct: CreateNamedStruct): Unit = { + createNamedStruct.children.foreach(validate) + assert( + createNamedStruct.checkInputDataTypes().isSuccess, + s"Input types of CreateNamedStruct must be valid, but got: " + + s"${createNamedStruct.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateGetArrayStructFields(getArrayStructFields: GetArrayStructFields): Unit = { + validate(getArrayStructFields.child) + } + + private def validateGetMapValue(getMapValue: GetMapValue): Unit = { + validate(getMapValue.child) + validate(getMapValue.key) + assert( + getMapValue.checkInputDataTypes().isSuccess, + s"Input types of GetMapValue must be valid, but got: " + + s"${getMapValue.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateCreateMap(createMap: CreateMap): Unit = { + 
createMap.children.foreach(validate) + assert( + createMap.checkInputDataTypes().isSuccess, + s"Input types of CreateMap must be valid, but got: " + + s"${createMap.children.map(_.dataType.typeName).mkString(", ")}" Review Comment: Fix here and in the same way below: ```suggestion "Input types of CreateMap must be valid, but got: " + createMap.children.map(_.dataType.typeName).mkString(", ") ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala: ########## @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + ArrayDistinct, + ArrayInsert, + ArrayJoin, + ArrayMax, + ArrayMin, + ArraysZip, + AttributeReference, + BinaryExpression, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + ExtractANSIIntervalDays, + GetArrayStructFields, + GetMapValue, + GetStructField, + Literal, + MapConcat, + MapContainsKey, + MapEntries, + MapFromEntries, + MapKeys, + MapValues, + NamedExpression, + Predicate, + RuntimeReplaceable, + StringRPad, + StringToMap, + TimeZoneAwareExpression, + UnaryMinus +} +import org.apache.spark.sql.types.BooleanType + +/** + * The [[ExpressionResolutionValidator]] performs the validation work on the expression tree for the + * [[ResolutionValidator]]. These two components work together recursively validating the + * logical plan. You can find more info in the [[ResolutionValidator]] scaladoc. + */ +class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { + + /** + * Validate resolved expression tree. The principle is the same as + * [[ResolutionValidator.validate]]. 
+ */ + def validate(expression: Expression): Unit = { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + case getMapValue: GetMapValue => + validateGetMapValue(getMapValue) + case binaryExpression: BinaryExpression => + validateBinaryExpression(binaryExpression) + case extractANSIIntervalDay: ExtractANSIIntervalDays => + validateExtractANSIIntervalDays(extractANSIIntervalDay) + case literal: Literal => + validateLiteral(literal) + case predicate: Predicate => + validatePredicate(predicate) + case stringRPad: StringRPad => + validateStringRPad(stringRPad) + case unaryMinus: UnaryMinus => + validateUnaryMinus(unaryMinus) + case getStructField: GetStructField => + validateGetStructField(getStructField) + case createNamedStruct: CreateNamedStruct => + validateCreateNamedStruct(createNamedStruct) + case getArrayStructFields: GetArrayStructFields => + validateGetArrayStructFields(getArrayStructFields) + case createMap: CreateMap => + validateCreateMap(createMap) + case stringToMap: StringToMap => + validateStringToMap(stringToMap) + case mapContainsKey: MapContainsKey => + validateMapContainsKey(mapContainsKey) + case mapConcat: MapConcat => + validateMapConcat(mapConcat) + case mapKeys: MapKeys => + validateMapKeys(mapKeys) + case mapValues: MapValues => + validateMapValues(mapValues) + case mapEntries: MapEntries => + validateMapEntries(mapEntries) + case mapFromEntries: MapFromEntries => + validateMapFromEntries(mapFromEntries) + case createArray: CreateArray => + validateCreateArray(createArray) + case arrayDistinct: ArrayDistinct => + validateArrayDistinct(arrayDistinct) + case arrayInsert: ArrayInsert => + validateArrayInsert(arrayInsert) + case arrayJoin: ArrayJoin => + validateArrayJoin(arrayJoin) + case arrayMax: ArrayMax => + validateArrayMax(arrayMax) + case arrayMin: ArrayMin => + validateArrayMin(arrayMin) + case arraysZip: ArraysZip => + 
validateArraysZip(arraysZip) + case runtimeReplaceable: RuntimeReplaceable => + validateRuntimeReplaceable(runtimeReplaceable) + case timezoneExpression: TimeZoneAwareExpression => + validateTimezoneExpression(timezoneExpression) + } + } + + def validateProjectList(projectList: Seq[NamedExpression]): Unit = { + projectList.foreach(expression => { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + } + }) + } + + private def validatePredicate(predicate: Predicate) = { + predicate.children.foreach(validate) + assert( + predicate.dataType == BooleanType, + s"Output type of a predicate must be a boolean, but got: ${predicate.dataType.typeName}" + ) + assert( + predicate.checkInputDataTypes().isSuccess, + s"Input types of a predicate must be valid, but got: " + + s"${predicate.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateStringRPad(stringRPad: StringRPad) = { + validate(stringRPad.first) + validate(stringRPad.second) + validate(stringRPad.third) + assert( + stringRPad.checkInputDataTypes().isSuccess, + s"Input types of rpad must be valid, but got: " + + s"${stringRPad.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateAttributeReference(attributeReference: AttributeReference): Unit = { + assert( + resolutionValidator.attributeScopeStack.top.contains(attributeReference), + s"Attribute $attributeReference is missing from attribute scope: " + + s"${resolutionValidator.attributeScopeStack.top}" + ) + } + + private def validateAlias(alias: Alias): Unit = { + validate(alias.child) + } + + private def validateBinaryExpression(binaryExpression: BinaryExpression): Unit = { + validate(binaryExpression.left) + validate(binaryExpression.right) + assert( + binaryExpression.checkInputDataTypes().isSuccess, + s"Input types of a binary expression must be valid, but got: " + + 
s"${binaryExpression.children.map(_.dataType.typeName).mkString(", ")}" + ) + + binaryExpression match { + case timezoneExpression: TimeZoneAwareExpression => + assert(timezoneExpression.timeZoneId.nonEmpty, "Timezone expression must have a timezone") + case _ => + } + } + + private def validateExtractANSIIntervalDays( + extractANSIIntervalDays: ExtractANSIIntervalDays): Unit = { + validate(extractANSIIntervalDays.child) + } + + private def validateLiteral(literal: Literal): Unit = {} + + private def validateUnaryMinus(unaryMinus: UnaryMinus): Unit = { + validate(unaryMinus.child) + assert( + unaryMinus.checkInputDataTypes().isSuccess, + s"Input types of a unary minus must be valid, but got: " + + s"${unaryMinus.child.dataType.typeName.mkString(", ")}" + ) + } + + private def validateGetStructField(getStructField: GetStructField): Unit = { + validate(getStructField.child) + } + + private def validateCreateNamedStruct(createNamedStruct: CreateNamedStruct): Unit = { + createNamedStruct.children.foreach(validate) + assert( + createNamedStruct.checkInputDataTypes().isSuccess, + s"Input types of CreateNamedStruct must be valid, but got: " + + s"${createNamedStruct.children.map(_.dataType.typeName).mkString(", ")}" Review Comment: ```suggestion createNamedStruct.children.map(_.dataType.typeName).mkString(", ") ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala: ########## @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + ArrayDistinct, + ArrayInsert, + ArrayJoin, + ArrayMax, + ArrayMin, + ArraysZip, + AttributeReference, + BinaryExpression, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + ExtractANSIIntervalDays, + GetArrayStructFields, + GetMapValue, + GetStructField, + Literal, + MapConcat, + MapContainsKey, + MapEntries, + MapFromEntries, + MapKeys, + MapValues, + NamedExpression, + Predicate, + RuntimeReplaceable, + StringRPad, + StringToMap, + TimeZoneAwareExpression, + UnaryMinus +} +import org.apache.spark.sql.types.BooleanType + +/** + * The [[ExpressionResolutionValidator]] performs the validation work on the expression tree for the + * [[ResolutionValidator]]. These two components work together recursively validating the + * logical plan. You can find more info in the [[ResolutionValidator]] scaladoc. + */ +class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { + + /** + * Validate resolved expression tree. The principle is the same as + * [[ResolutionValidator.validate]]. 
+ */ + def validate(expression: Expression): Unit = { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + case getMapValue: GetMapValue => + validateGetMapValue(getMapValue) + case binaryExpression: BinaryExpression => + validateBinaryExpression(binaryExpression) + case extractANSIIntervalDay: ExtractANSIIntervalDays => + validateExtractANSIIntervalDays(extractANSIIntervalDay) + case literal: Literal => + validateLiteral(literal) + case predicate: Predicate => + validatePredicate(predicate) + case stringRPad: StringRPad => + validateStringRPad(stringRPad) + case unaryMinus: UnaryMinus => + validateUnaryMinus(unaryMinus) + case getStructField: GetStructField => + validateGetStructField(getStructField) + case createNamedStruct: CreateNamedStruct => + validateCreateNamedStruct(createNamedStruct) + case getArrayStructFields: GetArrayStructFields => + validateGetArrayStructFields(getArrayStructFields) + case createMap: CreateMap => + validateCreateMap(createMap) + case stringToMap: StringToMap => + validateStringToMap(stringToMap) + case mapContainsKey: MapContainsKey => + validateMapContainsKey(mapContainsKey) + case mapConcat: MapConcat => + validateMapConcat(mapConcat) + case mapKeys: MapKeys => + validateMapKeys(mapKeys) + case mapValues: MapValues => + validateMapValues(mapValues) + case mapEntries: MapEntries => + validateMapEntries(mapEntries) + case mapFromEntries: MapFromEntries => + validateMapFromEntries(mapFromEntries) + case createArray: CreateArray => + validateCreateArray(createArray) + case arrayDistinct: ArrayDistinct => + validateArrayDistinct(arrayDistinct) + case arrayInsert: ArrayInsert => + validateArrayInsert(arrayInsert) + case arrayJoin: ArrayJoin => + validateArrayJoin(arrayJoin) + case arrayMax: ArrayMax => + validateArrayMax(arrayMax) + case arrayMin: ArrayMin => + validateArrayMin(arrayMin) + case arraysZip: ArraysZip => + 
validateArraysZip(arraysZip) + case runtimeReplaceable: RuntimeReplaceable => + validateRuntimeReplaceable(runtimeReplaceable) + case timezoneExpression: TimeZoneAwareExpression => + validateTimezoneExpression(timezoneExpression) + } + } + + def validateProjectList(projectList: Seq[NamedExpression]): Unit = { + projectList.foreach(expression => { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + } + }) + } + + private def validatePredicate(predicate: Predicate) = { + predicate.children.foreach(validate) + assert( + predicate.dataType == BooleanType, + s"Output type of a predicate must be a boolean, but got: ${predicate.dataType.typeName}" + ) + assert( + predicate.checkInputDataTypes().isSuccess, + s"Input types of a predicate must be valid, but got: " + + s"${predicate.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateStringRPad(stringRPad: StringRPad) = { + validate(stringRPad.first) + validate(stringRPad.second) + validate(stringRPad.third) + assert( + stringRPad.checkInputDataTypes().isSuccess, + s"Input types of rpad must be valid, but got: " + + s"${stringRPad.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateAttributeReference(attributeReference: AttributeReference): Unit = { + assert( + resolutionValidator.attributeScopeStack.top.contains(attributeReference), + s"Attribute $attributeReference is missing from attribute scope: " + + s"${resolutionValidator.attributeScopeStack.top}" + ) + } + + private def validateAlias(alias: Alias): Unit = { + validate(alias.child) + } + + private def validateBinaryExpression(binaryExpression: BinaryExpression): Unit = { + validate(binaryExpression.left) + validate(binaryExpression.right) + assert( + binaryExpression.checkInputDataTypes().isSuccess, + s"Input types of a binary expression must be valid, but got: " + + 
s"${binaryExpression.children.map(_.dataType.typeName).mkString(", ")}" Review Comment: ```suggestion binaryExpression.children.map(_.dataType.typeName).mkString(", ") ``` ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ExpressionResolutionValidator.scala: ########## @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis.resolver + +import org.apache.spark.sql.catalyst.expressions.{ + Alias, + ArrayDistinct, + ArrayInsert, + ArrayJoin, + ArrayMax, + ArrayMin, + ArraysZip, + AttributeReference, + BinaryExpression, + CreateArray, + CreateMap, + CreateNamedStruct, + Expression, + ExtractANSIIntervalDays, + GetArrayStructFields, + GetMapValue, + GetStructField, + Literal, + MapConcat, + MapContainsKey, + MapEntries, + MapFromEntries, + MapKeys, + MapValues, + NamedExpression, + Predicate, + RuntimeReplaceable, + StringRPad, + StringToMap, + TimeZoneAwareExpression, + UnaryMinus +} +import org.apache.spark.sql.types.BooleanType + +/** + * The [[ExpressionResolutionValidator]] performs the validation work on the expression tree for the + * [[ResolutionValidator]]. 
These two components work together recursively validating the + * logical plan. You can find more info in the [[ResolutionValidator]] scaladoc. + */ +class ExpressionResolutionValidator(resolutionValidator: ResolutionValidator) { + + /** + * Validate resolved expression tree. The principle is the same as + * [[ResolutionValidator.validate]]. + */ + def validate(expression: Expression): Unit = { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + case getMapValue: GetMapValue => + validateGetMapValue(getMapValue) + case binaryExpression: BinaryExpression => + validateBinaryExpression(binaryExpression) + case extractANSIIntervalDay: ExtractANSIIntervalDays => + validateExtractANSIIntervalDays(extractANSIIntervalDay) + case literal: Literal => + validateLiteral(literal) + case predicate: Predicate => + validatePredicate(predicate) + case stringRPad: StringRPad => + validateStringRPad(stringRPad) + case unaryMinus: UnaryMinus => + validateUnaryMinus(unaryMinus) + case getStructField: GetStructField => + validateGetStructField(getStructField) + case createNamedStruct: CreateNamedStruct => + validateCreateNamedStruct(createNamedStruct) + case getArrayStructFields: GetArrayStructFields => + validateGetArrayStructFields(getArrayStructFields) + case createMap: CreateMap => + validateCreateMap(createMap) + case stringToMap: StringToMap => + validateStringToMap(stringToMap) + case mapContainsKey: MapContainsKey => + validateMapContainsKey(mapContainsKey) + case mapConcat: MapConcat => + validateMapConcat(mapConcat) + case mapKeys: MapKeys => + validateMapKeys(mapKeys) + case mapValues: MapValues => + validateMapValues(mapValues) + case mapEntries: MapEntries => + validateMapEntries(mapEntries) + case mapFromEntries: MapFromEntries => + validateMapFromEntries(mapFromEntries) + case createArray: CreateArray => + validateCreateArray(createArray) + case 
arrayDistinct: ArrayDistinct => + validateArrayDistinct(arrayDistinct) + case arrayInsert: ArrayInsert => + validateArrayInsert(arrayInsert) + case arrayJoin: ArrayJoin => + validateArrayJoin(arrayJoin) + case arrayMax: ArrayMax => + validateArrayMax(arrayMax) + case arrayMin: ArrayMin => + validateArrayMin(arrayMin) + case arraysZip: ArraysZip => + validateArraysZip(arraysZip) + case runtimeReplaceable: RuntimeReplaceable => + validateRuntimeReplaceable(runtimeReplaceable) + case timezoneExpression: TimeZoneAwareExpression => + validateTimezoneExpression(timezoneExpression) + } + } + + def validateProjectList(projectList: Seq[NamedExpression]): Unit = { + projectList.foreach(expression => { + expression match { + case attributeReference: AttributeReference => + validateAttributeReference(attributeReference) + case alias: Alias => + validateAlias(alias) + } + }) + } + + private def validatePredicate(predicate: Predicate) = { + predicate.children.foreach(validate) + assert( + predicate.dataType == BooleanType, + s"Output type of a predicate must be a boolean, but got: ${predicate.dataType.typeName}" + ) + assert( + predicate.checkInputDataTypes().isSuccess, + s"Input types of a predicate must be valid, but got: " + + s"${predicate.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateStringRPad(stringRPad: StringRPad) = { + validate(stringRPad.first) + validate(stringRPad.second) + validate(stringRPad.third) + assert( + stringRPad.checkInputDataTypes().isSuccess, + s"Input types of rpad must be valid, but got: " + + s"${stringRPad.children.map(_.dataType.typeName).mkString(", ")}" + ) + } + + private def validateAttributeReference(attributeReference: AttributeReference): Unit = { + assert( + resolutionValidator.attributeScopeStack.top.contains(attributeReference), + s"Attribute $attributeReference is missing from attribute scope: " + + s"${resolutionValidator.attributeScopeStack.top}" + ) + } + + private def validateAlias(alias: Alias): 
Unit = { + validate(alias.child) + } + + private def validateBinaryExpression(binaryExpression: BinaryExpression): Unit = { + validate(binaryExpression.left) + validate(binaryExpression.right) + assert( + binaryExpression.checkInputDataTypes().isSuccess, + s"Input types of a binary expression must be valid, but got: " + + s"${binaryExpression.children.map(_.dataType.typeName).mkString(", ")}" + ) + + binaryExpression match { + case timezoneExpression: TimeZoneAwareExpression => + assert(timezoneExpression.timeZoneId.nonEmpty, "Timezone expression must have a timezone") + case _ => + } + } + + private def validateExtractANSIIntervalDays( + extractANSIIntervalDays: ExtractANSIIntervalDays): Unit = { + validate(extractANSIIntervalDays.child) + } + + private def validateLiteral(literal: Literal): Unit = {} + + private def validateUnaryMinus(unaryMinus: UnaryMinus): Unit = { + validate(unaryMinus.child) + assert( + unaryMinus.checkInputDataTypes().isSuccess, + s"Input types of a unary minus must be valid, but got: " + + s"${unaryMinus.child.dataType.typeName.mkString(", ")}" + ) + } + + private def validateGetStructField(getStructField: GetStructField): Unit = { + validate(getStructField.child) + } + + private def validateCreateNamedStruct(createNamedStruct: CreateNamedStruct): Unit = { + createNamedStruct.children.foreach(validate) + assert( + createNamedStruct.checkInputDataTypes().isSuccess, + s"Input types of CreateNamedStruct must be valid, but got: " + Review Comment: ```suggestion "Input types of CreateNamedStruct must be valid, but got: " + ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
