fqaiser94 commented on a change in pull request #29795:
URL: https://github.com/apache/spark/pull/29795#discussion_r494698625
##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
##########
@@ -541,57 +541,105 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
}
/**
- * Adds/replaces field in struct by name.
+ * Represents an operation to be applied to the fields of a struct.
*/
-case class WithFields(
- structExpr: Expression,
- names: Seq[String],
- valExprs: Seq[Expression]) extends Unevaluable {
+trait StructFieldsOperation {
- assert(names.length == valExprs.length)
+ val resolver: Resolver = SQLConf.get.resolver
+
+ /**
+ * Returns an updated list of StructFields and Expressions that will ultimately be used
+ * as the fields argument for [[StructType]] and as the children argument for
+ * [[CreateNamedStruct]] respectively inside of [[UpdateFields]].
+ */
+ def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)]
+}
+
+/**
+ * Add or replace a field by name.
+ *
+ * We extend [[Unevaluable]] here to ensure that [[UpdateFields]] can include it as part of its
+ * children, and thereby enable the analyzer to resolve and transform valExpr as necessary.
+ */
+case class WithField(name: String, valExpr: Expression)
+ extends Unevaluable with StructFieldsOperation {
+
+ override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] = {
+ val newFieldExpr = (StructField(name, valExpr.dataType, valExpr.nullable), valExpr)
+ if (values.exists { case (field, _) => resolver(field.name, name) }) {
+ values.map {
+ case (field, _) if resolver(field.name, name) => newFieldExpr
+ case x => x
+ }
+ } else {
+ values :+ newFieldExpr
+ }
+ }
+
+ override def children: Seq[Expression] = valExpr :: Nil
+
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
Review comment:
done
##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
##########
@@ -541,57 +541,105 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
}
/**
- * Adds/replaces field in struct by name.
+ * Represents an operation to be applied to the fields of a struct.
*/
-case class WithFields(
- structExpr: Expression,
- names: Seq[String],
- valExprs: Seq[Expression]) extends Unevaluable {
+trait StructFieldsOperation {
- assert(names.length == valExprs.length)
+ val resolver: Resolver = SQLConf.get.resolver
+
+ /**
+ * Returns an updated list of StructFields and Expressions that will ultimately be used
+ * as the fields argument for [[StructType]] and as the children argument for
+ * [[CreateNamedStruct]] respectively inside of [[UpdateFields]].
+ */
+ def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)]
+}
+
+/**
+ * Add or replace a field by name.
+ *
+ * We extend [[Unevaluable]] here to ensure that [[UpdateFields]] can include it as part of its
+ * children, and thereby enable the analyzer to resolve and transform valExpr as necessary.
+ */
+case class WithField(name: String, valExpr: Expression)
+ extends Unevaluable with StructFieldsOperation {
+
+ override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] = {
+ val newFieldExpr = (StructField(name, valExpr.dataType, valExpr.nullable), valExpr)
+ if (values.exists { case (field, _) => resolver(field.name, name) }) {
Review comment:
thanks for sharing the code, done
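(For anyone following along: a minimal standalone sketch of the replace-or-append semantics that `WithField.apply` implements in the hunk above, with plain name/value pairs standing in for `StructField`/`Expression` and case-insensitive matching standing in for the configured resolver. This is an illustration only, not the actual Spark code.)
```scala
// Simplified sketch of WithField.apply: replace every field whose name
// matches (duplicates included), otherwise append the new field at the end.
def withField(fields: Seq[(String, Int)], name: String, value: Int): Seq[(String, Int)] = {
  val exists = fields.exists { case (n, _) => n.equalsIgnoreCase(name) }
  if (exists) {
    fields.map {
      case (n, _) if n.equalsIgnoreCase(name) => (name, value)
      case other => other
    }
  } else {
    fields :+ (name -> value)
  }
}

// withField(Seq("a" -> 1, "b" -> 2), "b", 3) == Seq(("a", 1), ("b", 3))
// withField(Seq("a" -> 1, "b" -> 2), "c", 3) == Seq(("a", 1), ("b", 2), ("c", 3))
```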
##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
##########
@@ -541,57 +541,105 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
}
/**
- * Adds/replaces field in struct by name.
+ * Represents an operation to be applied to the fields of a struct.
*/
-case class WithFields(
- structExpr: Expression,
- names: Seq[String],
- valExprs: Seq[Expression]) extends Unevaluable {
+trait StructFieldsOperation {
- assert(names.length == valExprs.length)
+ val resolver: Resolver = SQLConf.get.resolver
+
+ /**
+ * Returns an updated list of StructFields and Expressions that will ultimately be used
+ * as the fields argument for [[StructType]] and as the children argument for
+ * [[CreateNamedStruct]] respectively inside of [[UpdateFields]].
+ */
+ def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)]
+}
+
+/**
+ * Add or replace a field by name.
+ *
+ * We extend [[Unevaluable]] here to ensure that [[UpdateFields]] can include it as part of its
+ * children, and thereby enable the analyzer to resolve and transform valExpr as necessary.
+ */
+case class WithField(name: String, valExpr: Expression)
+ extends Unevaluable with StructFieldsOperation {
+
+ override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] = {
+ val newFieldExpr = (StructField(name, valExpr.dataType, valExpr.nullable), valExpr)
+ if (values.exists { case (field, _) => resolver(field.name, name) }) {
+ values.map {
+ case (field, _) if resolver(field.name, name) => newFieldExpr
+ case x => x
+ }
+ } else {
+ values :+ newFieldExpr
+ }
+ }
+
+ override def children: Seq[Expression] = valExpr :: Nil
+
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+
+ override def prettyName: String = "WithField"
+}
+
+/**
+ * Drop a field by name.
+ */
+case class DropField(name: String) extends StructFieldsOperation {
+ override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] =
+ values.filterNot { case (field, _) => resolver(field.name, name) }
+}
+
+/**
+ * Updates fields in a struct.
+ */
+case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperation])
+ extends Unevaluable {
override def checkInputDataTypes(): TypeCheckResult = {
- if (!structExpr.dataType.isInstanceOf[StructType]) {
- TypeCheckResult.TypeCheckFailure(
- "struct argument should be struct type, got: " +
structExpr.dataType.catalogString)
+ val dataType = structExpr.dataType
+ if (!dataType.isInstanceOf[StructType]) {
+ TypeCheckResult.TypeCheckFailure("struct argument should be struct type, got: " +
+ dataType.catalogString)
+ } else if (newExprs.isEmpty) {
+ TypeCheckResult.TypeCheckFailure("cannot drop all fields in struct")
} else {
TypeCheckResult.TypeCheckSuccess
}
}
- override def children: Seq[Expression] = structExpr +: valExprs
+ override def children: Seq[Expression] = structExpr +: fieldOps.collect {
+ case e: Expression => e
+ }
- override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType]
+ override def dataType: StructType = StructType(newFields)
override def nullable: Boolean = structExpr.nullable
- override def prettyName: String = "with_fields"
+ override def prettyName: String = "update_fields"
- lazy val evalExpr: Expression = {
- val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
- case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression])
+ private lazy val existingFieldExprs: Seq[(StructField, Expression)] =
+ structExpr.dataType.asInstanceOf[StructType].fields.zipWithIndex.map {
+ case (field, i) => (field, GetStructField(structExpr, i))
}
- val addOrReplaceExprs = names.zip(valExprs)
-
- val resolver = SQLConf.get.resolver
- val newExprs = addOrReplaceExprs.foldLeft(existingExprs) {
- case (resultExprs, newExpr @ (newExprName, _)) =>
- if (resultExprs.exists(x => resolver(x._1, newExprName))) {
- resultExprs.map {
- case (name, _) if resolver(name, newExprName) => newExpr
- case x => x
- }
- } else {
- resultExprs :+ newExpr
- }
- }.flatMap { case (name, expr) => Seq(Literal(name), expr) }
+ private lazy val newFieldExprs: Seq[(StructField, Expression)] =
+ fieldOps.foldLeft(existingFieldExprs)((exprs, op) => op(exprs))
- val expr = CreateNamedStruct(newExprs)
- if (structExpr.nullable) {
- If(IsNull(structExpr), Literal(null, expr.dataType), expr)
- } else {
- expr
- }
+ private lazy val newFields: Seq[StructField] = newFieldExprs.map(_._1)
+
+ lazy val newExprs: Seq[Expression] = newFieldExprs.map(_._2)
+
+ private lazy val createNamedStructExpr = CreateNamedStruct(newFieldExprs.flatMap {
+ case (field, expr) => Seq(Literal(field.name), expr)
+ })
+
+ lazy val evalExpr: Expression = if (structExpr.nullable) {
Review comment:
done
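(Side note for readers: a rough standalone model of how the `foldLeft` over `fieldOps` in the hunk above composes `WithField`-style and `DropField`-style operations. Simplified types stand in for Catalyst expressions; this is not the actual Spark implementation.)
```scala
// Simplified model of UpdateFields: each operation transforms the running
// list of (name, value) pairs, and the operations are applied left to right.
sealed trait Op {
  def apply(fields: Seq[(String, Int)]): Seq[(String, Int)]
}
final case class Add(name: String, value: Int) extends Op {
  def apply(fields: Seq[(String, Int)]): Seq[(String, Int)] =
    if (fields.exists(_._1 == name)) {
      fields.map { case (n, v) => if (n == name) (name, value) else (n, v) }
    } else {
      fields :+ (name -> value)
    }
}
final case class Drop(name: String) extends Op {
  def apply(fields: Seq[(String, Int)]): Seq[(String, Int)] =
    fields.filterNot(_._1 == name)
}

val existing = Seq("a" -> 1, "b" -> 2)
val ops: Seq[Op] = Seq(Add("c", 3), Drop("b"))
val result = ops.foldLeft(existing)((fields, op) => op(fields))
// result == Seq(("a", 1), ("c", 3))
```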
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/Column.scala
##########
@@ -901,39 +901,125 @@ class Column(val expr: Expression) extends Logging {
* // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
* }}}
*
+ * This method supports adding/replacing nested fields directly e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a.c", lit(3)).withField("a.d", lit(4)))
+ * // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+ * }}}
+ *
+ * However, if you are going to add/replace multiple nested fields, it is more optimal to extract
+ * out the nested struct before adding/replacing multiple fields e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a", $"struct_col.a".withField("c", lit(3)).withField("d", lit(4))))
+ * // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+ * }}}
+ *
* @group expr_ops
* @since 3.1.0
*/
// scalastyle:on line.size.limit
def withField(fieldName: String, col: Column): Column = withExpr {
require(fieldName != null, "fieldName cannot be null")
require(col != null, "col cannot be null")
+ updateFieldsHelper(expr, nameParts(fieldName), name => WithField(name, col.expr))
+ }
- val nameParts = if (fieldName.isEmpty) {
+ // scalastyle:off line.size.limit
+ /**
+ * An expression that drops fields in `StructType` by name.
Review comment:
I've made this change, but now that I think about it, I don't think it actually qualifies as a "noop". We still reconstruct the struct unfortunately, e.g.
```
val structType = StructType(Seq(
StructField("a", IntegerType, nullable = false),
StructField("b", IntegerType, nullable = true),
StructField("c", IntegerType, nullable = false)))
val structLevel1: DataFrame = spark.createDataFrame(
sparkContext.parallelize(Row(Row(1, null, 3)) :: Nil),
StructType(Seq(StructField("a", structType, nullable = false))))
structLevel1.withColumn("a", 'a.dropFields("d")).explain()
== Physical Plan ==
*(1) Project [named_struct(a, a#1.a, b, a#1.b, c, a#1.c) AS a#3]
+- *(1) Scan ExistingRDD[a#1]
```
Should I revert this?
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/Column.scala
##########
@@ -901,39 +901,125 @@ class Column(val expr: Expression) extends Logging {
* // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
* }}}
*
+ * This method supports adding/replacing nested fields directly e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a.c", lit(3)).withField("a.d", lit(4)))
+ * // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+ * }}}
+ *
+ * However, if you are going to add/replace multiple nested fields, it is more optimal to extract
+ * out the nested struct before adding/replacing multiple fields e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a", $"struct_col.a".withField("c", lit(3)).withField("d", lit(4))))
+ * // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+ * }}}
+ *
* @group expr_ops
* @since 3.1.0
*/
// scalastyle:on line.size.limit
def withField(fieldName: String, col: Column): Column = withExpr {
require(fieldName != null, "fieldName cannot be null")
require(col != null, "col cannot be null")
+ updateFieldsHelper(expr, nameParts(fieldName), name => WithField(name, col.expr))
+ }
- val nameParts = if (fieldName.isEmpty) {
+ // scalastyle:off line.size.limit
+ /**
+ * An expression that drops fields in `StructType` by name.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".dropFields("b"))
+ * // result: {"a":1}
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".dropFields("c"))
+ * // result: {"a":1,"b":2}
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col")
+ * df.select($"struct_col".dropFields("b", "c"))
+ * // result: {"a":1}
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".dropFields("a", "b"))
+ * // result: org.apache.spark.sql.AnalysisException: cannot resolve 'update_fields(update_fields(`struct_col`))' due to data type mismatch: cannot drop all fields in struct
+ *
+ * val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+ * df.select($"struct_col".dropFields("b"))
+ * // result: null of type struct<a:int>
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+ * df.select($"struct_col".dropFields("b"))
+ * // result: {"a":1}
+ *
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".dropFields("a.b"))
+ * // result: {"a":{"a":1}}
+ *
+ * val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+ * df.select($"struct_col".dropFields("a.c"))
+ * // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
+ * }}}
+ *
+ * This method supports dropping multiple nested fields directly e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".dropFields("a.b", "a.c"))
+ * // result: {"a":{"a":1}}
+ * }}}
+ *
+ * However, if you are going to drop multiple nested fields, it is more optimal to extract
+ * out the nested struct before dropping multiple fields from it e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a", $"struct_col.a".dropFields("b", "c")))
+ * // result: {"a":{"a":1}}
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.1.0
+ */
+ // scalastyle:on line.size.limit
+ def dropFields(fieldNames: String*): Column = withExpr {
+ def dropField(structExpr: Expression, fieldName: String): UpdateFields =
+ updateFieldsHelper(structExpr, nameParts(fieldName), name => DropField(name))
+
+ fieldNames.tail.foldLeft(dropField(expr, fieldNames.head)) {
+ (resExpr, fieldName) => dropField(resExpr, fieldName)
+ }
+ }
+
+ private def nameParts(fieldName: String): Seq[String] = {
+ require(fieldName != null, "fieldName cannot be null")
+
+ if (fieldName.isEmpty) {
fieldName :: Nil
Review comment:
We've discussed this before [here](https://github.com/apache/spark/pull/27066#discussion_r448416127) :)
It's needed for `withField`, and I think we should support it in `dropFields` as well because `Dataset.drop` supports it:
```
scala> Seq((1, 2)).toDF("a", "").drop("").printSchema
root
|-- a: integer (nullable = false)
```
I've added a test case to demonstrate this works on the `dropFields` side but otherwise left the code unchanged.
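(For illustration only, a hypothetical sketch of what such a test case could look like on the `dropFields` side; the actual test added in the PR may be structured differently, and this assumes a `SparkSession` named `spark` with the new `dropFields` API available.)
```scala
import org.apache.spark.sql.functions.col

// Hypothetical example: drop a field whose name is the empty string,
// mirroring the Dataset.drop("") behaviour shown above.
val df = spark.sql("SELECT named_struct('a', 1, '', 2) struct_col")
val result = df.select(col("struct_col").dropFields(""))
// Expectation: the empty-named field is dropped, leaving struct<a:int>.
```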
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/Column.scala
##########
@@ -901,39 +901,125 @@ class Column(val expr: Expression) extends Logging {
* // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
* }}}
*
+ * This method supports adding/replacing nested fields directly e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a.c", lit(3)).withField("a.d", lit(4)))
+ * // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+ * }}}
+ *
+ * However, if you are going to add/replace multiple nested fields, it is more optimal to extract
+ * out the nested struct before adding/replacing multiple fields e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a", $"struct_col.a".withField("c", lit(3)).withField("d", lit(4))))
+ * // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+ * }}}
+ *
* @group expr_ops
* @since 3.1.0
*/
// scalastyle:on line.size.limit
def withField(fieldName: String, col: Column): Column = withExpr {
require(fieldName != null, "fieldName cannot be null")
require(col != null, "col cannot be null")
+ updateFieldsHelper(expr, nameParts(fieldName), name => WithField(name, col.expr))
+ }
- val nameParts = if (fieldName.isEmpty) {
+ // scalastyle:off line.size.limit
+ /**
+ * An expression that drops fields in `StructType` by name.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".dropFields("b"))
+ * // result: {"a":1}
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".dropFields("c"))
+ * // result: {"a":1,"b":2}
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col")
+ * df.select($"struct_col".dropFields("b", "c"))
+ * // result: {"a":1}
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+ * df.select($"struct_col".dropFields("a", "b"))
+ * // result: org.apache.spark.sql.AnalysisException: cannot resolve 'update_fields(update_fields(`struct_col`))' due to data type mismatch: cannot drop all fields in struct
+ *
+ * val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+ * df.select($"struct_col".dropFields("b"))
+ * // result: null of type struct<a:int>
+ *
+ * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+ * df.select($"struct_col".dropFields("b"))
+ * // result: {"a":1}
+ *
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".dropFields("a.b"))
+ * // result: {"a":{"a":1}}
+ *
+ * val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+ * df.select($"struct_col".dropFields("a.c"))
+ * // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
+ * }}}
+ *
+ * This method supports dropping multiple nested fields directly e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".dropFields("a.b", "a.c"))
+ * // result: {"a":{"a":1}}
+ * }}}
+ *
+ * However, if you are going to drop multiple nested fields, it is more optimal to extract
+ * out the nested struct before dropping multiple fields from it e.g.
+ *
+ * {{{
+ * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+ * df.select($"struct_col".withField("a", $"struct_col.a".dropFields("b", "c")))
+ * // result: {"a":{"a":1}}
+ * }}}
+ *
+ * @group expr_ops
+ * @since 3.1.0
+ */
+ // scalastyle:on line.size.limit
+ def dropFields(fieldNames: String*): Column = withExpr {
+ def dropField(structExpr: Expression, fieldName: String): UpdateFields =
+ updateFieldsHelper(structExpr, nameParts(fieldName), name => DropField(name))
+
+ fieldNames.tail.foldLeft(dropField(expr, fieldNames.head)) {
+ (resExpr, fieldName) => dropField(resExpr, fieldName)
+ }
+ }
+
+ private def nameParts(fieldName: String): Seq[String] = {
+ require(fieldName != null, "fieldName cannot be null")
+
+ if (fieldName.isEmpty) {
fieldName :: Nil
} else {
CatalystSqlParser.parseMultipartIdentifier(fieldName)
}
- withFieldHelper(expr, nameParts, Nil, col.expr)
}
- private def withFieldHelper(
- struct: Expression,
- namePartsRemaining: Seq[String],
- namePartsDone: Seq[String],
- value: Expression) : WithFields = {
- val name = namePartsRemaining.head
+ private def updateFieldsHelper(
+ structExpr: Expression,
Review comment:
done
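(For context, a rough analogy of the recursion `updateFieldsHelper` performs for a dotted path such as `a.c`, using plain Scala maps in place of Catalyst expressions. Unlike the real code, which resolves each name part against the struct's schema and fails analysis if an intermediate field is missing, this sketch simply assumes every intermediate level exists. Illustration only, not the PR's implementation.)
```scala
// Recursively descend the name parts, rebuilding each intermediate level and
// applying the operation (add/replace or drop) only at the innermost struct.
def updateNested(
    struct: Map[String, Any],
    nameParts: List[String],
    op: (Map[String, Any], String) => Map[String, Any]): Map[String, Any] =
  nameParts match {
    case last :: Nil => op(struct, last)
    case head :: tail =>
      val child = struct(head).asInstanceOf[Map[String, Any]]
      struct + (head -> updateNested(child, tail, op))
    case Nil => struct
  }

val data = Map("a" -> Map("a" -> 1, "b" -> 2))
// WithField-like operation at path a.c:
val added = updateNested(data, List("a", "c"), (m, name) => m + (name -> 3))
// added == Map("a" -> Map("a" -> 1, "b" -> 2, "c" -> 3))
// DropField-like operation at path a.b:
val dropped = updateNested(data, List("a", "b"), (m, name) => m - name)
// dropped == Map("a" -> Map("a" -> 1))
```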
##########
File path: sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
##########
@@ -159,6 +160,14 @@ abstract class QueryTest extends PlanTest {
checkAnswer(df, expectedAnswer.collect())
}
+ protected def checkAnswer(
+ df: => DataFrame,
+ expectedAnswer: Seq[Row],
Review comment:
done
##########
File path: sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession {
+
+ private def nestedColName(d: Int, colNum: Int): String = s"nested${d}Col$colNum"
+
+ private def nestedStructType(
+ depths: Seq[Int], colNums: Seq[Int], nullable: Boolean): StructType = {
Review comment:
done
##########
File path: sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession {
+
+ private def nestedColName(d: Int, colNum: Int): String = s"nested${d}Col$colNum"
+
+ private def nestedStructType(
+ depths: Seq[Int], colNums: Seq[Int], nullable: Boolean): StructType = {
+ if (depths.length == 1) {
+ StructType(colNums.map { colNum =>
+ StructField(nestedColName(depths.head, colNum), IntegerType, nullable = false)
+ })
+ } else {
+ val depth = depths.head
+ val fields = colNums.foldLeft(Seq.empty[StructField]) {
+ case (structFields, colNum) if colNum == 0 =>
+ val nested = nestedStructType(depths.tail, colNums, nullable)
+ structFields :+ StructField(nestedColName(depth, colNum), nested, nullable)
+ case (structFields, colNum) =>
+ structFields :+ StructField(nestedColName(depth, colNum), IntegerType, nullable = false)
+ }
+ StructType(fields)
+ }
+ }
+
+ private def nestedRow(depths: Seq[Int], colNums: Seq[Int]): Row = {
+ if (depths.length == 1) {
+ Row.fromSeq(colNums)
+ } else {
+ val values = colNums.foldLeft(Seq.empty[Any]) {
+ case (values, colNum) if colNum == 0 => values :+ nestedRow(depths.tail, colNums)
+ case (values, colNum) => values :+ colNum
+ }
+ Row.fromSeq(values)
+ }
+ }
+
+ /**
+ * Utility function for generating a DataFrame with nested columns.
+ *
+ * @param depth: The depth to which to create nested columns.
+ * @param numColsAtEachDepth: The number of columns to create at each depth. The value of each
+ * column will be the same as its index (IntegerType) at that depth
+ * unless the index = 0, in which case it may be a StructType which
+ * represents the next depth.
+ * @param nullable: This value is used to set the nullability of StructType columns.
+ */
+ private def nestedDf(
+ depth: Int, numColsAtEachDepth: Int, nullable: Boolean = false): DataFrame = {
Review comment:
done
##########
File path: sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession {
+
+ private def nestedColName(d: Int, colNum: Int): String = s"nested${d}Col$colNum"
+
+ private def nestedStructType(
+ depths: Seq[Int], colNums: Seq[Int], nullable: Boolean): StructType = {
+ if (depths.length == 1) {
+ StructType(colNums.map { colNum =>
+ StructField(nestedColName(depths.head, colNum), IntegerType, nullable = false)
+ })
+ } else {
+ val depth = depths.head
+ val fields = colNums.foldLeft(Seq.empty[StructField]) {
+ case (structFields, colNum) if colNum == 0 =>
+ val nested = nestedStructType(depths.tail, colNums, nullable)
+ structFields :+ StructField(nestedColName(depth, colNum), nested, nullable)
+ case (structFields, colNum) =>
+ structFields :+ StructField(nestedColName(depth, colNum), IntegerType, nullable = false)
+ }
+ StructType(fields)
+ }
+ }
+
+ private def nestedRow(depths: Seq[Int], colNums: Seq[Int]): Row = {
+ if (depths.length == 1) {
+ Row.fromSeq(colNums)
+ } else {
+ val values = colNums.foldLeft(Seq.empty[Any]) {
+ case (values, colNum) if colNum == 0 => values :+ nestedRow(depths.tail, colNums)
+ case (values, colNum) => values :+ colNum
+ }
+ Row.fromSeq(values)
+ }
+ }
+
+ /**
+ * Utility function for generating a DataFrame with nested columns.
+ *
+ * @param depth: The depth to which to create nested columns.
+ * @param numColsAtEachDepth: The number of columns to create at each depth. The value of each
+ * column will be the same as its index (IntegerType) at that depth
+ * unless the index = 0, in which case it may be a StructType which
+ * represents the next depth.
+ * @param nullable: This value is used to set the nullability of StructType columns.
+ */
+ private def nestedDf(
+ depth: Int, numColsAtEachDepth: Int, nullable: Boolean = false): DataFrame = {
+ require(depth > 0)
+ require(numColsAtEachDepth > 0)
+
+ val depths = 1 to depth
+ val colNums = 0 until numColsAtEachDepth
+ val nestedColumn = nestedRow(depths, colNums)
+ val nestedColumnDataType = nestedStructType(depths, colNums, nullable)
+
+ spark.createDataFrame(
+ sparkContext.parallelize(Row(nestedColumn) :: Nil),
+ StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable))))
+ }
+
+ test("nestedDf should generate nested DataFrames") {
+ checkAnswer(
+ nestedDf(1, 1),
+ Row(Row(0)) :: Nil,
+ StructType(Seq(StructField("nested0Col0", StructType(Seq(
+ StructField("nested1Col0", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswer(
+ nestedDf(1, 2),
+ Row(Row(0, 1)) :: Nil,
+ StructType(Seq(StructField("nested0Col0", StructType(Seq(
+ StructField("nested1Col0", IntegerType, nullable = false),
+ StructField("nested1Col1", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswer(
+ nestedDf(2, 1),
+ Row(Row(Row(0))) :: Nil,
+ StructType(Seq(StructField("nested0Col0", StructType(Seq(
+ StructField("nested1Col0", StructType(Seq(
+ StructField("nested2Col0", IntegerType, nullable = false))),
+ nullable = false))),
+ nullable = false))))
+
+ checkAnswer(
+ nestedDf(2, 2),
+ Row(Row(Row(0, 1), 1)) :: Nil,
+ StructType(Seq(StructField("nested0Col0", StructType(Seq(
+ StructField("nested1Col0", StructType(Seq(
+ StructField("nested2Col0", IntegerType, nullable = false),
+ StructField("nested2Col1", IntegerType, nullable = false))),
+ nullable = false),
+ StructField("nested1Col1", IntegerType, nullable = false))),
+ nullable = false))))
+
+ checkAnswer(
+ nestedDf(2, 2, nullable = true),
+ Row(Row(Row(0, 1), 1)) :: Nil,
+ StructType(Seq(StructField("nested0Col0", StructType(Seq(
+ StructField("nested1Col0", StructType(Seq(
+ StructField("nested2Col0", IntegerType, nullable = false),
+ StructField("nested2Col1", IntegerType, nullable = false))),
+ nullable = true),
+ StructField("nested1Col1", IntegerType, nullable = false))),
+ nullable = true))))
+ }
+
+ // simulates how a user would add/drop nested fields in a performant manner
+ private def addDropNestedColumns(
+ column: Column,
Review comment:
done
##########
File path: sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala
##########
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession {
Review comment:
done, please take a closer look at this as this was relatively new to me.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]