fqaiser94 commented on a change in pull request #29322:
URL: https://github.com/apache/spark/pull/29322#discussion_r469535097
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
##########
@@ -541,57 +541,94 @@ case class StringToMap(text: Expression, pairDelim:
Expression, keyValueDelim: E
}
/**
- * Adds/replaces field in struct by name.
+ * Represents an operation to be applied to the fields of a struct.
*/
-case class WithFields(
- structExpr: Expression,
- names: Seq[String],
- valExprs: Seq[Expression]) extends Unevaluable {
+trait StructFieldsOperation {
- assert(names.length == valExprs.length)
+ val resolver: Resolver = SQLConf.get.resolver
+
+ /**
+ * Returns an updated list of expressions which will ultimately be used as
the children argument
+ * for [[CreateNamedStruct]].
+ */
+ def apply(exprs: Seq[(String, Expression)]): Seq[(String, Expression)]
+}
+
+/**
+ * Add or replace a field by name.
+ */
+case class WithField(name: String, valExpr: Expression)
+ extends Unevaluable with StructFieldsOperation {
+
+ override def apply(exprs: Seq[(String, Expression)]): Seq[(String,
Expression)] =
+ if (exprs.exists(x => resolver(x._1, name))) {
+ exprs.map {
+ case (existingName, _) if resolver(existingName, name) => (name,
valExpr)
+ case x => x
+ }
+ } else {
+ exprs :+ (name, valExpr)
+ }
+
+ override def children: Seq[Expression] = valExpr :: Nil
+
+ override def dataType: DataType = throw new UnresolvedException(this,
"dataType")
+
+ override def nullable: Boolean = throw new UnresolvedException(this,
"nullable")
+
+ override def prettyName: String = "WithField"
+}
+
+/**
+ * Drop a field by name.
+ */
+case class DropField(name: String) extends StructFieldsOperation {
+ override def apply(exprs: Seq[(String, Expression)]): Seq[(String,
Expression)] =
+ exprs.filterNot(expr => resolver(expr._1, name))
+}
+
+/**
+ * Updates fields in struct by name.
+ */
+case class UpdateFields(structExpr: Expression, fieldOps:
Seq[StructFieldsOperation])
+ extends Unevaluable {
override def checkInputDataTypes(): TypeCheckResult = {
- if (!structExpr.dataType.isInstanceOf[StructType]) {
- TypeCheckResult.TypeCheckFailure(
- "struct argument should be struct type, got: " +
structExpr.dataType.catalogString)
+ val dataType = structExpr.dataType
+ if (!dataType.isInstanceOf[StructType]) {
+ TypeCheckResult.TypeCheckFailure("struct argument should be struct type,
got: " +
+ dataType.catalogString)
+ } else if (newExprs.isEmpty) {
+ TypeCheckResult.TypeCheckFailure("cannot drop all fields in struct")
} else {
TypeCheckResult.TypeCheckSuccess
}
}
- override def children: Seq[Expression] = structExpr +: valExprs
+ override def children: Seq[Expression] = structExpr +: fieldOps.collect {
+ case e: Expression => e
Review comment:
This won't work. A working alternative is `case w: WithField => w`, but I
would prefer to leave it as-is (i.e. `case e: Expression => e`) because it is
more future-proof. Let me know, though, if you don't think future-proofing is
worth considering here for now.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]