This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ce1fca991860 [SPARK-49420][CONNECT][SQL] Add shared interface for DataFrameNaFunctions
ce1fca991860 is described below
commit ce1fca991860916cef207c5c3d98bb7074e0d3f0
Author: Herman van Hovell <[email protected]>
AuthorDate: Tue Sep 3 00:46:33 2024 -0400
[SPARK-49420][CONNECT][SQL] Add shared interface for DataFrameNaFunctions
### What changes were proposed in this pull request?
This PR creates a shared interface for DataFrameNaFunctions.
### Why are the changes needed?
We are creating a shared Scala SQL API interface. This class is part of
this interface.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47961 from hvanhovell/SPARK-49420.
Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../apache/spark/sql/DataFrameNaFunctions.scala | 376 ++++----------------
.../main/scala/org/apache/spark/sql/Dataset.scala | 22 +-
.../spark/sql/api}/DataFrameNaFunctions.scala | 283 +++------------
.../scala/org/apache/spark/sql/api/Dataset.scala | 12 +
.../apache/spark/sql/DataFrameNaFunctions.scala | 378 ++++-----------------
.../main/scala/org/apache/spark/sql/Dataset.scala | 22 +-
6 files changed, 206 insertions(+), 887 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 7d484d82ec25..c06cbbc0cdb4 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -17,118 +17,26 @@
package org.apache.spark.sql
-import java.util.Locale
-
import scala.jdk.CollectionConverters._
import org.apache.spark.connect.proto.{NAReplace, Relation}
import org.apache.spark.connect.proto.Expression.{Literal => GLiteral}
import org.apache.spark.connect.proto.NAReplace.Replacement
-import org.apache.spark.util.ArrayImplicits._
/**
* Functionality for working with missing data in `DataFrame`s.
*
* @since 3.4.0
*/
-final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession,
root: Relation) {
+final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession,
root: Relation)
+ extends api.DataFrameNaFunctions[Dataset] {
import sparkSession.RichColumn
- /**
- * Returns a new `DataFrame` that drops rows containing any null or NaN
values.
- *
- * @since 3.4.0
- */
- def drop(): DataFrame = buildDropDataFrame(None, None)
-
- /**
- * Returns a new `DataFrame` that drops rows containing null or NaN values.
- *
- * If `how` is "any", then drop rows containing any null or NaN values. If
`how` is "all", then
- * drop rows only if every column is null or NaN for that row.
- *
- * @since 3.4.0
- */
- def drop(how: String): DataFrame = {
- buildDropDataFrame(None, buildMinNonNulls(how))
- }
-
- /**
- * Returns a new `DataFrame` that drops rows containing any null or NaN
values in the specified
- * columns.
- *
- * @since 3.4.0
- */
- def drop(cols: Array[String]): DataFrame = drop(cols.toImmutableArraySeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that drops rows containing any
null or NaN values
- * in the specified columns.
- *
- * @since 3.4.0
- */
- def drop(cols: Seq[String]): DataFrame = buildDropDataFrame(Some(cols), None)
-
- /**
- * Returns a new `DataFrame` that drops rows containing null or NaN values
in the specified
- * columns.
- *
- * If `how` is "any", then drop rows containing any null or NaN values in
the specified columns.
- * If `how` is "all", then drop rows only if every specified column is null
or NaN for that row.
- *
- * @since 3.4.0
- */
- def drop(how: String, cols: Array[String]): DataFrame = drop(how,
cols.toImmutableArraySeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that drops rows containing
null or NaN values in
- * the specified columns.
- *
- * If `how` is "any", then drop rows containing any null or NaN values in
the specified columns.
- * If `how` is "all", then drop rows only if every specified column is null
or NaN for that row.
- *
- * @since 3.4.0
- */
- def drop(how: String, cols: Seq[String]): DataFrame = {
- buildDropDataFrame(Some(cols), buildMinNonNulls(how))
- }
-
- /**
- * Returns a new `DataFrame` that drops rows containing less than
`minNonNulls` non-null and
- * non-NaN values.
- *
- * @since 3.4.0
- */
- def drop(minNonNulls: Int): DataFrame = {
- buildDropDataFrame(None, Some(minNonNulls))
- }
+ override protected def drop(minNonNulls: Option[Int]): Dataset[Row] =
+ buildDropDataFrame(None, minNonNulls)
- /**
- * Returns a new `DataFrame` that drops rows containing less than
`minNonNulls` non-null and
- * non-NaN values in the specified columns.
- *
- * @since 3.4.0
- */
- def drop(minNonNulls: Int, cols: Array[String]): DataFrame =
- drop(minNonNulls, cols.toImmutableArraySeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that drops rows containing
less than `minNonNulls`
- * non-null and non-NaN values in the specified columns.
- *
- * @since 3.4.0
- */
- def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
- buildDropDataFrame(Some(cols), Some(minNonNulls))
- }
-
- private def buildMinNonNulls(how: String): Option[Int] = {
- how.toLowerCase(Locale.ROOT) match {
- case "any" => None // No-Op. Do nothing.
- case "all" => Some(1)
- case _ => throw new IllegalArgumentException(s"how ($how) must be 'any'
or 'all'")
- }
- }
+ override protected def drop(minNonNulls: Option[Int], cols: Seq[String]):
Dataset[Row] =
+ buildDropDataFrame(Option(cols), minNonNulls)
private def buildDropDataFrame(
cols: Option[Seq[String]],
@@ -140,110 +48,42 @@ final class DataFrameNaFunctions private[sql]
(sparkSession: SparkSession, root:
}
}
- /**
- * Returns a new `DataFrame` that replaces null or NaN values in numeric
columns with `value`.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def fill(value: Long): DataFrame = {
buildFillDataFrame(None, GLiteral.newBuilder().setLong(value).build())
}
- /**
- * Returns a new `DataFrame` that replaces null or NaN values in specified
numeric columns. If a
- * specified column is not a numeric column, it is ignored.
- *
- * @since 3.4.0
- */
- def fill(value: Long, cols: Array[String]): DataFrame = fill(value,
cols.toImmutableArraySeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN
values in specified
- * numeric columns. If a specified column is not a numeric column, it is
ignored.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def fill(value: Long, cols: Seq[String]): DataFrame = {
buildFillDataFrame(Some(cols),
GLiteral.newBuilder().setLong(value).build())
}
- /**
- * Returns a new `DataFrame` that replaces null or NaN values in numeric
columns with `value`.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def fill(value: Double): DataFrame = {
buildFillDataFrame(None, GLiteral.newBuilder().setDouble(value).build())
}
- /**
- * Returns a new `DataFrame` that replaces null or NaN values in specified
numeric columns. If a
- * specified column is not a numeric column, it is ignored.
- *
- * @since 3.4.0
- */
- def fill(value: Double, cols: Array[String]): DataFrame = fill(value,
cols.toImmutableArraySeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN
values in specified
- * numeric columns. If a specified column is not a numeric column, it is
ignored.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def fill(value: Double, cols: Seq[String]): DataFrame = {
buildFillDataFrame(Some(cols),
GLiteral.newBuilder().setDouble(value).build())
}
- /**
- * Returns a new `DataFrame` that replaces null values in string columns
with `value`.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def fill(value: String): DataFrame = {
buildFillDataFrame(None, GLiteral.newBuilder().setString(value).build())
}
- /**
- * Returns a new `DataFrame` that replaces null values in specified string
columns. If a
- * specified column is not a string column, it is ignored.
- *
- * @since 3.4.0
- */
- def fill(value: String, cols: Array[String]): DataFrame = fill(value,
cols.toImmutableArraySeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that replaces null values in
specified string
- * columns. If a specified column is not a string column, it is ignored.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def fill(value: String, cols: Seq[String]): DataFrame = {
buildFillDataFrame(Some(cols),
GLiteral.newBuilder().setString(value).build())
}
- /**
- * Returns a new `DataFrame` that replaces null values in boolean columns
with `value`.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def fill(value: Boolean): DataFrame = {
buildFillDataFrame(None, GLiteral.newBuilder().setBoolean(value).build())
}
- /**
- * Returns a new `DataFrame` that replaces null values in specified boolean
columns. If a
- * specified column is not a boolean column, it is ignored.
- *
- * @since 3.4.0
- */
- def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value,
cols.toImmutableArraySeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that replaces null values in
specified boolean
- * columns. If a specified column is not a boolean column, it is ignored.
- *
- * @since 3.4.0
- */
+ /** @inheritdoc */
def fill(value: Boolean, cols: Seq[String]): DataFrame = {
buildFillDataFrame(Some(cols),
GLiteral.newBuilder().setBoolean(value).build())
}
@@ -256,43 +96,7 @@ final class DataFrameNaFunctions private[sql]
(sparkSession: SparkSession, root:
}
}
- /**
- * Returns a new `DataFrame` that replaces null values.
- *
- * The key of the map is the column name, and the value of the map is the
replacement value. The
- * value must be of the following type: `Integer`, `Long`, `Float`,
`Double`, `String`,
- * `Boolean`. Replacement values are cast to the column data type.
- *
- * For example, the following replaces null values in column "A" with string
"unknown", and null
- * values in column "B" with numeric value 1.0.
- * {{{
- * import com.google.common.collect.ImmutableMap;
- * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0));
- * }}}
- *
- * @since 3.4.0
- */
- def fill(valueMap: java.util.Map[String, Any]): DataFrame =
fillMap(valueMap.asScala.toSeq)
-
- /**
- * Returns a new `DataFrame` that replaces null values.
- *
- * The key of the map is the column name, and the value of the map is the
replacement value. The
- * value must be of the following type: `Integer`, `Long`, `Float`,
`Double`, `String`,
- * `Boolean`. Replacement values are cast to the column data type.
- *
- * For example, the following replaces null values in column "A" with string
"unknown", and null
- * values in column "B" with numeric value 1.0.
- * {{{
- * import com.google.common.collect.ImmutableMap;
- * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0));
- * }}}
- *
- * @since 3.4.0
- */
- def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq)
-
- private def fillMap(values: Seq[(String, Any)]): DataFrame = {
+ protected def fillMap(values: Seq[(String, Any)]): DataFrame = {
sparkSession.newDataFrame { builder =>
val fillNaBuilder = builder.getFillNaBuilder.setInput(root)
values.map { case (colName, replaceValue) =>
@@ -301,104 +105,13 @@ final class DataFrameNaFunctions private[sql]
(sparkSession: SparkSession, root:
}
}
- /**
- * Replaces values matching keys in `replacement` map with the corresponding
values.
- *
- * {{{
- * import com.google.common.collect.ImmutableMap;
- *
- * // Replaces all occurrences of 1.0 with 2.0 in column "height".
- * df.na.replace("height", ImmutableMap.of(1.0, 2.0));
- *
- * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column
"name".
- * df.na.replace("name", ImmutableMap.of("UNKNOWN", "unnamed"));
- *
- * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string
columns.
- * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
- * }}}
- *
- * @param col
- * name of the column to apply the value replacement. If `col` is "*",
replacement is applied
- * on all string, numeric or boolean columns.
- * @param replacement
- * value replacement map. Key and value of `replacement` map must have the
same type, and can
- * only be doubles, strings or booleans. The map value can have nulls.
- * @since 3.4.0
- */
- def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame =
- replace(col, replacement.asScala.toMap)
-
- /**
- * (Scala-specific) Replaces values matching keys in `replacement` map.
- *
- * {{{
- * // Replaces all occurrences of 1.0 with 2.0 in column "height".
- * df.na.replace("height", Map(1.0 -> 2.0));
- *
- * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column
"name".
- * df.na.replace("name", Map("UNKNOWN" -> "unnamed"));
- *
- * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string
columns.
- * df.na.replace("*", Map("UNKNOWN" -> "unnamed"));
- * }}}
- *
- * @param col
- * name of the column to apply the value replacement. If `col` is "*",
replacement is applied
- * on all string, numeric or boolean columns.
- * @param replacement
- * value replacement map. Key and value of `replacement` map must have the
same type, and can
- * only be doubles, strings or booleans. The map value can have nulls.
- * @since 3.4.0
- */
+ /** @inheritdoc */
def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
val cols = if (col != "*") Some(Seq(col)) else None
buildReplaceDataFrame(cols, buildReplacement(replacement))
}
- /**
- * Replaces values matching keys in `replacement` map with the corresponding
values.
- *
- * {{{
- * import com.google.common.collect.ImmutableMap;
- *
- * // Replaces all occurrences of 1.0 with 2.0 in column "height" and
"weight".
- * df.na.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0,
2.0));
- *
- * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column
"firstname" and "lastname".
- * df.na.replace(new String[] {"firstname", "lastname"},
ImmutableMap.of("UNKNOWN", "unnamed"));
- * }}}
- *
- * @param cols
- * list of columns to apply the value replacement. If `col` is "*",
replacement is applied on
- * all string, numeric or boolean columns.
- * @param replacement
- * value replacement map. Key and value of `replacement` map must have the
same type, and can
- * only be doubles, strings or booleans. The map value can have nulls.
- * @since 3.4.0
- */
- def replace[T](cols: Array[String], replacement: java.util.Map[T, T]):
DataFrame = {
- replace(cols.toImmutableArraySeq, replacement.asScala.toMap)
- }
-
- /**
- * (Scala-specific) Replaces values matching keys in `replacement` map.
- *
- * {{{
- * // Replaces all occurrences of 1.0 with 2.0 in column "height" and
"weight".
- * df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0));
- *
- * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column
"firstname" and "lastname".
- * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" ->
"unnamed"));
- * }}}
- *
- * @param cols
- * list of columns to apply the value replacement. If `col` is "*",
replacement is applied on
- * all string, numeric or boolean columns.
- * @param replacement
- * value replacement map. Key and value of `replacement` map must have the
same type, and can
- * only be doubles, strings or booleans. The map value can have nulls.
- * @since 3.4.0
- */
+ /** @inheritdoc */
def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
buildReplaceDataFrame(Some(cols), buildReplacement(replacement))
}
@@ -441,4 +154,59 @@ final class DataFrameNaFunctions private[sql]
(sparkSession: SparkSession, root:
case v =>
throw new IllegalArgumentException(s"Unsupported value type
${v.getClass.getName} ($v).")
}
+
+ /** @inheritdoc */
+ override def drop(): DataFrame = super.drop()
+
+ /** @inheritdoc */
+ override def drop(cols: Array[String]): DataFrame = super.drop(cols)
+
+ /** @inheritdoc */
+ override def drop(cols: Seq[String]): DataFrame = super.drop(cols)
+
+ /** @inheritdoc */
+ override def drop(how: String, cols: Array[String]): DataFrame =
super.drop(how, cols)
+
+ /** @inheritdoc */
+ override def drop(minNonNulls: Int, cols: Array[String]): DataFrame =
+ super.drop(minNonNulls, cols)
+
+ /** @inheritdoc */
+ override def drop(how: String): DataFrame = super.drop(how)
+
+ /** @inheritdoc */
+ override def drop(how: String, cols: Seq[String]): DataFrame =
super.drop(how, cols)
+
+ /** @inheritdoc */
+ override def drop(minNonNulls: Int): DataFrame = super.drop(minNonNulls)
+
+ /** @inheritdoc */
+ override def drop(minNonNulls: Int, cols: Seq[String]): DataFrame =
+ super.drop(minNonNulls, cols)
+
+ /** @inheritdoc */
+ override def fill(value: Long, cols: Array[String]): DataFrame =
super.fill(value, cols)
+
+ /** @inheritdoc */
+ override def fill(value: Double, cols: Array[String]): DataFrame =
super.fill(value, cols)
+
+ /** @inheritdoc */
+ override def fill(value: String, cols: Array[String]): DataFrame =
super.fill(value, cols)
+
+ /** @inheritdoc */
+ override def fill(value: Boolean, cols: Array[String]): DataFrame =
super.fill(value, cols)
+
+ /** @inheritdoc */
+ override def fill(valueMap: java.util.Map[String, Any]): DataFrame =
super.fill(valueMap)
+
+ /** @inheritdoc */
+ override def fill(valueMap: Map[String, Any]): DataFrame =
super.fill(valueMap)
+
+ /** @inheritdoc */
+ override def replace[T](col: String, replacement: java.util.Map[T, T]):
DataFrame =
+ super.replace[T](col, replacement)
+
+ /** @inheritdoc */
+ override def replace[T](cols: Array[String], replacement: java.util.Map[T,
T]): DataFrame =
+ super.replace(cols, replacement)
}
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index d18a76b06a48..ce21f18501a7 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -279,28 +279,10 @@ class Dataset[T] private[sql] (
}
}
- /**
- * Returns a [[DataFrameNaFunctions]] for working with missing data.
- * {{{
- * // Dropping rows containing any null values.
- * ds.na.drop()
- * }}}
- *
- * @group untypedrel
- * @since 3.4.0
- */
+ /** @inheritdoc */
def na: DataFrameNaFunctions = new DataFrameNaFunctions(sparkSession,
plan.getRoot)
- /**
- * Returns a [[DataFrameStatFunctions]] for working statistic functions
support.
- * {{{
- * // Finding frequent items in column with name 'a'.
- * ds.stat.freqItems(Seq("a"))
- * }}}
- *
- * @group untypedrel
- * @since 3.4.0
- */
+ /** @inheritdoc */
def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF())
private def buildJoin(right: Dataset[_])(f: proto.Join.Builder => Unit):
DataFrame = {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala
similarity index 51%
copy from
sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
copy to
sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala
index 2af5bce69087..7400f90992d8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala
@@ -14,20 +14,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
-package org.apache.spark.sql
-
-import java.{lang => jl}
-import java.util.Locale
+package org.apache.spark.sql.api
import scala.jdk.CollectionConverters._
+import _root_.java.util
+
import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.ExpressionUtils.column
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
import org.apache.spark.util.ArrayImplicits._
/**
@@ -36,15 +30,14 @@ import org.apache.spark.util.ArrayImplicits._
* @since 1.3.1
*/
@Stable
-final class DataFrameNaFunctions private[sql](df: DataFrame) {
- import df.sparkSession.RichColumn
+abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] {
/**
* Returns a new `DataFrame` that drops rows containing any null or NaN
values.
*
* @since 1.3.1
*/
- def drop(): DataFrame = drop0("any", outputAttributes)
+ def drop(): DS[Row] = drop("any")
/**
* Returns a new `DataFrame` that drops rows containing null or NaN values.
@@ -54,7 +47,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
- def drop(how: String): DataFrame = drop0(how, outputAttributes)
+ def drop(how: String): DS[Row] = drop(toMinNonNulls(how))
/**
* Returns a new `DataFrame` that drops rows containing any null or NaN
values
@@ -62,7 +55,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
- def drop(cols: Array[String]): DataFrame = drop(cols.toImmutableArraySeq)
+ def drop(cols: Array[String]): DS[Row] = drop(cols.toImmutableArraySeq)
/**
* (Scala-specific) Returns a new `DataFrame` that drops rows containing any
null or NaN values
@@ -70,7 +63,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
- def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols)
+ def drop(cols: Seq[String]): DS[Row] = drop(cols.size, cols)
/**
* Returns a new `DataFrame` that drops rows containing null or NaN values
@@ -81,7 +74,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
- def drop(how: String, cols: Array[String]): DataFrame = drop(how,
cols.toImmutableArraySeq)
+ def drop(how: String, cols: Array[String]): DS[Row] = drop(how,
cols.toImmutableArraySeq)
/**
* (Scala-specific) Returns a new `DataFrame` that drops rows containing
null or NaN values
@@ -92,9 +85,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
- def drop(how: String, cols: Seq[String]): DataFrame = {
- drop0(how, cols.map(df.resolve(_)))
- }
+ def drop(how: String, cols: Seq[String]): DS[Row] = drop(toMinNonNulls(how),
cols)
/**
* Returns a new `DataFrame` that drops rows containing
@@ -102,7 +93,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
{
*
* @since 1.3.1
*/
- def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns)
+ def drop(minNonNulls: Int): DS[Row] = drop(Option(minNonNulls))
/**
* Returns a new `DataFrame` that drops rows containing
@@ -110,7 +101,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def drop(minNonNulls: Int, cols: Array[String]): DataFrame =
+ def drop(minNonNulls: Int, cols: Array[String]): DS[Row] =
drop(minNonNulls, cols.toImmutableArraySeq)
/**
@@ -119,29 +110,39 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
- drop0(minNonNulls, cols.map(df.resolve(_)))
+ def drop(minNonNulls: Int, cols: Seq[String]): DS[Row] =
drop(Option(minNonNulls), cols)
+
+ private def toMinNonNulls(how: String): Option[Int] = {
+ how.toLowerCase(util.Locale.ROOT) match {
+ case "any" => None // No-Op. Do nothing.
+ case "all" => Some(1)
+ case _ => throw new IllegalArgumentException(s"how ($how) must be 'any'
or 'all'")
+ }
}
+ protected def drop(minNonNulls: Option[Int]): DS[Row]
+
+ protected def drop(minNonNulls: Option[Int], cols: Seq[String]): DS[Row]
+
/**
* Returns a new `DataFrame` that replaces null or NaN values in numeric
columns with `value`.
*
* @since 2.2.0
*/
- def fill(value: Long): DataFrame = fillValue(value, outputAttributes)
+ def fill(value: Long): DS[Row]
/**
* Returns a new `DataFrame` that replaces null or NaN values in numeric
columns with `value`.
* @since 1.3.1
*/
- def fill(value: Double): DataFrame = fillValue(value, outputAttributes)
+ def fill(value: Double): DS[Row]
/**
* Returns a new `DataFrame` that replaces null values in string columns
with `value`.
*
* @since 1.3.1
*/
- def fill(value: String): DataFrame = fillValue(value, outputAttributes)
+ def fill(value: String): DS[Row]
/**
* Returns a new `DataFrame` that replaces null or NaN values in specified
numeric columns.
@@ -149,7 +150,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 2.2.0
*/
- def fill(value: Long, cols: Array[String]): DataFrame = fill(value,
cols.toImmutableArraySeq)
+ def fill(value: Long, cols: Array[String]): DS[Row] = fill(value,
cols.toImmutableArraySeq)
/**
* Returns a new `DataFrame` that replaces null or NaN values in specified
numeric columns.
@@ -157,7 +158,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def fill(value: Double, cols: Array[String]): DataFrame = fill(value,
cols.toImmutableArraySeq)
+ def fill(value: Double, cols: Array[String]): DS[Row] = fill(value,
cols.toImmutableArraySeq)
/**
* (Scala-specific) Returns a new `DataFrame` that replaces null or NaN
values in specified
@@ -165,7 +166,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 2.2.0
*/
- def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value,
toAttributes(cols))
+ def fill(value: Long, cols: Seq[String]): DS[Row]
/**
* (Scala-specific) Returns a new `DataFrame` that replaces null or NaN
values in specified
@@ -173,7 +174,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value,
toAttributes(cols))
+ def fill(value: Double, cols: Seq[String]): DS[Row]
/**
@@ -182,7 +183,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def fill(value: String, cols: Array[String]): DataFrame = fill(value,
cols.toImmutableArraySeq)
+ def fill(value: String, cols: Array[String]): DS[Row] = fill(value,
cols.toImmutableArraySeq)
/**
* (Scala-specific) Returns a new `DataFrame` that replaces null values in
@@ -190,14 +191,14 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value,
toAttributes(cols))
+ def fill(value: String, cols: Seq[String]): DS[Row]
/**
* Returns a new `DataFrame` that replaces null values in boolean columns
with `value`.
*
* @since 2.3.0
*/
- def fill(value: Boolean): DataFrame = fillValue(value, outputAttributes)
+ def fill(value: Boolean): DS[Row]
/**
* (Scala-specific) Returns a new `DataFrame` that replaces null values in
specified
@@ -205,7 +206,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 2.3.0
*/
- def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value,
toAttributes(cols))
+ def fill(value: Boolean, cols: Seq[String]): DS[Row]
/**
* Returns a new `DataFrame` that replaces null values in specified boolean
columns.
@@ -213,8 +214,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 2.3.0
*/
- def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value,
cols.toImmutableArraySeq)
-
+ def fill(value: Boolean, cols: Array[String]): DS[Row] = fill(value,
cols.toImmutableArraySeq)
/**
* Returns a new `DataFrame` that replaces null values.
@@ -233,7 +233,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def fill(valueMap: java.util.Map[String, Any]): DataFrame =
fillMap(valueMap.asScala.toSeq)
+ def fill(valueMap: util.Map[String, Any]): DS[Row] =
fillMap(valueMap.asScala.toSeq)
/**
* (Scala-specific) Returns a new `DataFrame` that replaces null values.
@@ -253,7 +253,9 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq)
+ def fill(valueMap: Map[String, Any]): DS[Row] = fillMap(valueMap.toSeq)
+
+ protected def fillMap(values: Seq[(String, Any)]): DS[Row]
/**
* Replaces values matching keys in `replacement` map with the corresponding
values.
@@ -279,7 +281,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = {
+ def replace[T](col: String, replacement: util.Map[T, T]): DS[Row] = {
replace[T](col, replacement.asScala.toMap)
}
@@ -304,7 +306,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def replace[T](cols: Array[String], replacement: java.util.Map[T, T]):
DataFrame = {
+ def replace[T](cols: Array[String], replacement: util.Map[T, T]): DS[Row] = {
replace(cols.toImmutableArraySeq, replacement.asScala.toMap)
}
@@ -330,13 +332,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
- if (col == "*") {
- replace0(df.logicalPlan.output, replacement)
- } else {
- replace(Seq(col), replacement)
- }
- }
+ def replace[T](col: String, replacement: Map[T, T]): DS[Row]
/**
* (Scala-specific) Replaces values matching keys in `replacement` map.
@@ -357,194 +353,5 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
*
* @since 1.3.1
*/
- def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
- val attrs = cols.map { colName =>
- // Check column name exists
- val attr = df.resolve(colName) match {
- case a: Attribute => a
- case _ => throw
QueryExecutionErrors.nestedFieldUnsupportedError(colName)
- }
- attr
- }
- replace0(attrs, replacement)
- }
-
- private def replace0[T](attrs: Seq[Attribute], replacement: Map[T, T]):
DataFrame = {
- if (replacement.isEmpty || attrs.isEmpty) {
- return df
- }
-
- // Convert the NumericType in replacement map to DoubleType,
- // while leaving StringType, BooleanType and null untouched.
- val replacementMap: Map[_, _] = replacement.map {
- case (k, v: String) => (k, v)
- case (k, v: Boolean) => (k, v)
- case (k: String, null) => (k, null)
- case (k: Boolean, null) => (k, null)
- case (k, null) => (convertToDouble(k), null)
- case (k, v) => (convertToDouble(k), convertToDouble(v))
- }
-
- // targetColumnType is either DoubleType, StringType or BooleanType,
- // depending on the type of first key in replacement map.
- // Only fields of targetColumnType will perform replacement.
- val targetColumnType = replacement.head._1 match {
- case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long =>
DoubleType
- case _: jl.Boolean => BooleanType
- case _: String => StringType
- }
-
- val output = df.queryExecution.analyzed.output
- val projections = output.map { attr =>
- if (attrs.contains(attr) && (attr.dataType == targetColumnType ||
- (attr.dataType.isInstanceOf[NumericType] && targetColumnType ==
DoubleType))) {
- replaceCol(attr, replacementMap)
- } else {
- column(attr)
- }
- }
- df.select(projections : _*)
- }
-
- private def fillMap(values: Seq[(String, Any)]): DataFrame = {
- // Error handling
- val attrToValue = AttributeMap(values.map { case (colName, replaceValue) =>
- // Check column name exists
- val attr = df.resolve(colName) match {
- case a: Attribute => a
- case _ => throw
QueryExecutionErrors.nestedFieldUnsupportedError(colName)
- }
- // Check data type
- replaceValue match {
- case _: jl.Double | _: jl.Float | _: jl.Integer | _: jl.Long | _:
jl.Boolean | _: String =>
- // This is good
- case _ => throw new IllegalArgumentException(
- s"Unsupported value type ${replaceValue.getClass.getName}
($replaceValue).")
- }
- attr -> replaceValue
- })
-
- val output = df.queryExecution.analyzed.output
- val projections = output.map {
- attr => attrToValue.get(attr).map {
- case v: jl.Float => fillCol[Float](attr, v)
- case v: jl.Double => fillCol[Double](attr, v)
- case v: jl.Long => fillCol[Long](attr, v)
- case v: jl.Integer => fillCol[Integer](attr, v)
- case v: jl.Boolean => fillCol[Boolean](attr, v.booleanValue())
- case v: String => fillCol[String](attr, v)
- }.getOrElse(column(attr))
- }
- df.select(projections : _*)
- }
-
- /**
- * Returns a [[Column]] expression that replaces null value in column
defined by `attr`
- * with `replacement`.
- */
- private def fillCol[T](attr: Attribute, replacement: T): Column = {
- fillCol(attr.dataType, attr.name, column(attr), replacement)
- }
-
- /**
- * Returns a [[Column]] expression that replaces null value in `expr` with
`replacement`.
- * It uses the given `expr` as a column.
- */
- private def fillCol[T](dataType: DataType, name: String, expr: Column,
replacement: T): Column = {
- val colValue = dataType match {
- case DoubleType | FloatType =>
- nanvl(expr, lit(null)) // nanvl only supports these types
- case _ => expr
- }
- coalesce(colValue, lit(replacement).cast(dataType)).as(name)
- }
-
- /**
- * Returns a [[Column]] expression that replaces value matching key in
`replacementMap` with
- * value in `replacementMap`, using [[CaseWhen]].
- *
- * TODO: This can be optimized to use broadcast join when replacementMap is
large.
- */
- private def replaceCol[K, V](attr: Attribute, replacementMap: Map[K, V]):
Column = {
- def buildExpr(v: Any) = Cast(Literal(v), attr.dataType)
- val branches = replacementMap.flatMap { case (source, target) =>
- Seq(Literal(source), buildExpr(target))
- }.toSeq
- column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name)
- }
-
- private def convertToDouble(v: Any): Double = v match {
- case v: Float => v.toDouble
- case v: Double => v
- case v: Long => v.toDouble
- case v: Int => v.toDouble
- case v => throw new IllegalArgumentException(
- s"Unsupported value type ${v.getClass.getName} ($v).")
- }
-
- private def toAttributes(cols: Seq[String]): Seq[Attribute] = {
- cols.map(name => df.col(name).expr).collect {
- case a: Attribute => a
- }
- }
-
- private def outputAttributes: Seq[Attribute] = {
- df.queryExecution.analyzed.output
- }
-
- private def drop0(how: String, cols: Seq[NamedExpression]): DataFrame = {
- how.toLowerCase(Locale.ROOT) match {
- case "any" => drop0(cols.size, cols)
- case "all" => drop0(1, cols)
- case _ => throw new IllegalArgumentException(s"how ($how) must be 'any'
or 'all'")
- }
- }
-
- private def drop0(minNonNulls: Int, cols: Seq[NamedExpression]): DataFrame =
{
- // Filtering condition:
- // only keep the row if it has at least `minNonNulls` non-null and non-NaN
values.
- val predicate = AtLeastNNonNulls(minNonNulls, cols)
- df.filter(column(predicate))
- }
-
- private[sql] def fillValue(value: Any, cols: Option[Seq[String]]): DataFrame
= {
- fillValue(value, cols.map(toAttributes).getOrElse(outputAttributes))
- }
-
- /**
- * Returns a new `DataFrame` that replaces null or NaN values in the
specified
- * columns. If a specified column is not a numeric, string or boolean column,
- * it is ignored.
- */
- private def fillValue[T](value: T, cols: Seq[Attribute]): DataFrame = {
- // the fill[T] which T is Long/Double,
- // should apply on all the NumericType Column, for example:
- // val input = Seq[(java.lang.Integer, java.lang.Double)]((null,
164.3)).toDF("a","b")
- // input.na.fill(3.1)
- // the result is (3,164.3), not (null, 164.3)
- val targetType = value match {
- case _: Double | _: Long => NumericType
- case _: String => StringType
- case _: Boolean => BooleanType
- case _ => throw new IllegalArgumentException(
- s"Unsupported value type ${value.getClass.getName} ($value).")
- }
-
- val projections = outputAttributes.map { col =>
- val typeMatches = (targetType, col.dataType) match {
- case (NumericType, dt) => dt.isInstanceOf[NumericType]
- case (StringType, dt) => dt == StringType
- case (BooleanType, dt) => dt == BooleanType
- case _ =>
- throw new IllegalArgumentException(s"$targetType is not matched at
fillValue")
- }
- // Only fill if the column is part of the cols list.
- if (typeMatches && cols.exists(_.semanticEquals(col))) {
- fillCol(col.dataType, col.name, column(col), value)
- } else {
- column(col)
- }
- }
- df.select(projections : _*)
- }
+ def replace[T](cols: Seq[String], replacement: Map[T, T]): DS[Row]
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
index 226860df6813..49f77a1a6120 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -539,6 +539,18 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends
Serializable {
// scalastyle:off println
def show(numRows: Int, truncate: Int, vertical: Boolean): Unit
+ /**
+ * Returns a [[DataFrameNaFunctions]] for working with missing data.
+ * {{{
+ * // Dropping rows containing any null values.
+ * ds.na.drop()
+ * }}}
+ *
+ * @group untypedrel
+ * @since 1.6.0
+ */
+ def na: DataFrameNaFunctions[DS]
+
/**
* Returns a [[DataFrameStatFunctions]] for working statistic functions
support.
* {{{
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 2af5bce69087..53640f513fc8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -18,9 +18,6 @@
package org.apache.spark.sql
import java.{lang => jl}
-import java.util.Locale
-
-import scala.jdk.CollectionConverters._
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.expressions._
@@ -28,7 +25,6 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.types._
-import org.apache.spark.util.ArrayImplicits._
/**
* Functionality for working with missing data in `DataFrame`s.
@@ -36,300 +32,43 @@ import org.apache.spark.util.ArrayImplicits._
* @since 1.3.1
*/
@Stable
-final class DataFrameNaFunctions private[sql](df: DataFrame) {
+final class DataFrameNaFunctions private[sql](df: DataFrame)
+ extends api.DataFrameNaFunctions[Dataset] {
import df.sparkSession.RichColumn
- /**
- * Returns a new `DataFrame` that drops rows containing any null or NaN
values.
- *
- * @since 1.3.1
- */
- def drop(): DataFrame = drop0("any", outputAttributes)
-
- /**
- * Returns a new `DataFrame` that drops rows containing null or NaN values.
- *
- * If `how` is "any", then drop rows containing any null or NaN values.
- * If `how` is "all", then drop rows only if every column is null or NaN for
that row.
- *
- * @since 1.3.1
- */
- def drop(how: String): DataFrame = drop0(how, outputAttributes)
-
- /**
- * Returns a new `DataFrame` that drops rows containing any null or NaN
values
- * in the specified columns.
- *
- * @since 1.3.1
- */
- def drop(cols: Array[String]): DataFrame = drop(cols.toImmutableArraySeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that drops rows containing any
null or NaN values
- * in the specified columns.
- *
- * @since 1.3.1
- */
- def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols)
-
- /**
- * Returns a new `DataFrame` that drops rows containing null or NaN values
- * in the specified columns.
- *
- * If `how` is "any", then drop rows containing any null or NaN values in
the specified columns.
- * If `how` is "all", then drop rows only if every specified column is null
or NaN for that row.
- *
- * @since 1.3.1
- */
- def drop(how: String, cols: Array[String]): DataFrame = drop(how,
cols.toImmutableArraySeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that drops rows containing
null or NaN values
- * in the specified columns.
- *
- * If `how` is "any", then drop rows containing any null or NaN values in
the specified columns.
- * If `how` is "all", then drop rows only if every specified column is null
or NaN for that row.
- *
- * @since 1.3.1
- */
- def drop(how: String, cols: Seq[String]): DataFrame = {
- drop0(how, cols.map(df.resolve(_)))
+ protected def drop(minNonNulls: Option[Int]): Dataset[Row] = {
+ drop0(minNonNulls, outputAttributes)
}
- /**
- * Returns a new `DataFrame` that drops rows containing
- * less than `minNonNulls` non-null and non-NaN values.
- *
- * @since 1.3.1
- */
- def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns)
-
- /**
- * Returns a new `DataFrame` that drops rows containing
- * less than `minNonNulls` non-null and non-NaN values in the specified
columns.
- *
- * @since 1.3.1
- */
- def drop(minNonNulls: Int, cols: Array[String]): DataFrame =
- drop(minNonNulls, cols.toImmutableArraySeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that drops rows containing
less than
- * `minNonNulls` non-null and non-NaN values in the specified columns.
- *
- * @since 1.3.1
- */
- def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = {
- drop0(minNonNulls, cols.map(df.resolve(_)))
+ override protected def drop(minNonNulls: Option[Int], cols: Seq[String]):
Dataset[Row] = {
+ drop0(minNonNulls, cols.map(df.resolve))
}
- /**
- * Returns a new `DataFrame` that replaces null or NaN values in numeric
columns with `value`.
- *
- * @since 2.2.0
- */
+ /** @inheritdoc */
def fill(value: Long): DataFrame = fillValue(value, outputAttributes)
- /**
- * Returns a new `DataFrame` that replaces null or NaN values in numeric
columns with `value`.
- * @since 1.3.1
- */
+ /** @inheritdoc */
def fill(value: Double): DataFrame = fillValue(value, outputAttributes)
- /**
- * Returns a new `DataFrame` that replaces null values in string columns
with `value`.
- *
- * @since 1.3.1
- */
+ /** @inheritdoc */
def fill(value: String): DataFrame = fillValue(value, outputAttributes)
- /**
- * Returns a new `DataFrame` that replaces null or NaN values in specified
numeric columns.
- * If a specified column is not a numeric column, it is ignored.
- *
- * @since 2.2.0
- */
- def fill(value: Long, cols: Array[String]): DataFrame = fill(value,
cols.toImmutableArraySeq)
-
- /**
- * Returns a new `DataFrame` that replaces null or NaN values in specified
numeric columns.
- * If a specified column is not a numeric column, it is ignored.
- *
- * @since 1.3.1
- */
- def fill(value: Double, cols: Array[String]): DataFrame = fill(value,
cols.toImmutableArraySeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN
values in specified
- * numeric columns. If a specified column is not a numeric column, it is
ignored.
- *
- * @since 2.2.0
- */
+ /** @inheritdoc */
def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value,
toAttributes(cols))
- /**
- * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN
values in specified
- * numeric columns. If a specified column is not a numeric column, it is
ignored.
- *
- * @since 1.3.1
- */
+ /** @inheritdoc */
def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value,
toAttributes(cols))
-
- /**
- * Returns a new `DataFrame` that replaces null values in specified string
columns.
- * If a specified column is not a string column, it is ignored.
- *
- * @since 1.3.1
- */
- def fill(value: String, cols: Array[String]): DataFrame = fill(value,
cols.toImmutableArraySeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that replaces null values in
- * specified string columns. If a specified column is not a string column,
it is ignored.
- *
- * @since 1.3.1
- */
+ /** @inheritdoc */
def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value,
toAttributes(cols))
- /**
- * Returns a new `DataFrame` that replaces null values in boolean columns
with `value`.
- *
- * @since 2.3.0
- */
+ /** @inheritdoc */
def fill(value: Boolean): DataFrame = fillValue(value, outputAttributes)
- /**
- * (Scala-specific) Returns a new `DataFrame` that replaces null values in
specified
- * boolean columns. If a specified column is not a boolean column, it is
ignored.
- *
- * @since 2.3.0
- */
+ /** @inheritdoc */
def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value,
toAttributes(cols))
- /**
- * Returns a new `DataFrame` that replaces null values in specified boolean
columns.
- * If a specified column is not a boolean column, it is ignored.
- *
- * @since 2.3.0
- */
- def fill(value: Boolean, cols: Array[String]): DataFrame = fill(value,
cols.toImmutableArraySeq)
-
-
- /**
- * Returns a new `DataFrame` that replaces null values.
- *
- * The key of the map is the column name, and the value of the map is the
replacement value.
- * The value must be of the following type:
- * `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`.
- * Replacement values are cast to the column data type.
- *
- * For example, the following replaces null values in column "A" with string
"unknown", and
- * null values in column "B" with numeric value 1.0.
- * {{{
- * import com.google.common.collect.ImmutableMap;
- * df.na.fill(ImmutableMap.of("A", "unknown", "B", 1.0));
- * }}}
- *
- * @since 1.3.1
- */
- def fill(valueMap: java.util.Map[String, Any]): DataFrame =
fillMap(valueMap.asScala.toSeq)
-
- /**
- * (Scala-specific) Returns a new `DataFrame` that replaces null values.
- *
- * The key of the map is the column name, and the value of the map is the
replacement value.
- * The value must be of the following type: `Int`, `Long`, `Float`,
`Double`, `String`, `Boolean`.
- * Replacement values are cast to the column data type.
- *
- * For example, the following replaces null values in column "A" with string
"unknown", and
- * null values in column "B" with numeric value 1.0.
- * {{{
- * df.na.fill(Map(
- * "A" -> "unknown",
- * "B" -> 1.0
- * ))
- * }}}
- *
- * @since 1.3.1
- */
- def fill(valueMap: Map[String, Any]): DataFrame = fillMap(valueMap.toSeq)
-
- /**
- * Replaces values matching keys in `replacement` map with the corresponding
values.
- *
- * {{{
- * import com.google.common.collect.ImmutableMap;
- *
- * // Replaces all occurrences of 1.0 with 2.0 in column "height".
- * df.na.replace("height", ImmutableMap.of(1.0, 2.0));
- *
- * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column
"name".
- * df.na.replace("name", ImmutableMap.of("UNKNOWN", "unnamed"));
- *
- * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string
columns.
- * df.na.replace("*", ImmutableMap.of("UNKNOWN", "unnamed"));
- * }}}
- *
- * @param col name of the column to apply the value replacement. If `col` is
"*",
- * replacement is applied on all string, numeric or boolean
columns.
- * @param replacement value replacement map. Key and value of `replacement`
map must have
- * the same type, and can only be doubles, strings or
booleans.
- * The map value can have nulls.
- *
- * @since 1.3.1
- */
- def replace[T](col: String, replacement: java.util.Map[T, T]): DataFrame = {
- replace[T](col, replacement.asScala.toMap)
- }
-
- /**
- * Replaces values matching keys in `replacement` map with the corresponding
values.
- *
- * {{{
- * import com.google.common.collect.ImmutableMap;
- *
- * // Replaces all occurrences of 1.0 with 2.0 in column "height" and
"weight".
- * df.na.replace(new String[] {"height", "weight"}, ImmutableMap.of(1.0,
2.0));
- *
- * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column
"firstname" and "lastname".
- * df.na.replace(new String[] {"firstname", "lastname"},
ImmutableMap.of("UNKNOWN", "unnamed"));
- * }}}
- *
- * @param cols list of columns to apply the value replacement. If `col` is
"*",
- * replacement is applied on all string, numeric or boolean
columns.
- * @param replacement value replacement map. Key and value of `replacement`
map must have
- * the same type, and can only be doubles, strings or
booleans.
- * The map value can have nulls.
- *
- * @since 1.3.1
- */
- def replace[T](cols: Array[String], replacement: java.util.Map[T, T]):
DataFrame = {
- replace(cols.toImmutableArraySeq, replacement.asScala.toMap)
- }
-
- /**
- * (Scala-specific) Replaces values matching keys in `replacement` map.
- *
- * {{{
- * // Replaces all occurrences of 1.0 with 2.0 in column "height".
- * df.na.replace("height", Map(1.0 -> 2.0));
- *
- * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column
"name".
- * df.na.replace("name", Map("UNKNOWN" -> "unnamed"));
- *
- * // Replaces all occurrences of "UNKNOWN" with "unnamed" in all string
columns.
- * df.na.replace("*", Map("UNKNOWN" -> "unnamed"));
- * }}}
- *
- * @param col name of the column to apply the value replacement. If `col` is
"*",
- * replacement is applied on all string, numeric or boolean
columns.
- * @param replacement value replacement map. Key and value of `replacement`
map must have
- * the same type, and can only be doubles, strings or
booleans.
- * The map value can have nulls.
- *
- * @since 1.3.1
- */
+ /** @inheritdoc */
def replace[T](col: String, replacement: Map[T, T]): DataFrame = {
if (col == "*") {
replace0(df.logicalPlan.output, replacement)
@@ -338,25 +77,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
}
}
- /**
- * (Scala-specific) Replaces values matching keys in `replacement` map.
- *
- * {{{
- * // Replaces all occurrences of 1.0 with 2.0 in column "height" and
"weight".
- * df.na.replace("height" :: "weight" :: Nil, Map(1.0 -> 2.0));
- *
- * // Replaces all occurrences of "UNKNOWN" with "unnamed" in column
"firstname" and "lastname".
- * df.na.replace("firstname" :: "lastname" :: Nil, Map("UNKNOWN" ->
"unnamed"));
- * }}}
- *
- * @param cols list of columns to apply the value replacement. If `col` is
"*",
- * replacement is applied on all string, numeric or boolean
columns.
- * @param replacement value replacement map. Key and value of `replacement`
map must have
- * the same type, and can only be doubles, strings or
booleans.
- * The map value can have nulls.
- *
- * @since 1.3.1
- */
+ /** @inheritdoc */
def replace[T](cols: Seq[String], replacement: Map[T, T]): DataFrame = {
val attrs = cols.map { colName =>
// Check column name exists
@@ -406,7 +127,7 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
df.select(projections : _*)
}
- private def fillMap(values: Seq[(String, Any)]): DataFrame = {
+ protected def fillMap(values: Seq[(String, Any)]): DataFrame = {
// Error handling
val attrToValue = AttributeMap(values.map { case (colName, replaceValue) =>
// Check column name exists
@@ -492,18 +213,11 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
df.queryExecution.analyzed.output
}
- private def drop0(how: String, cols: Seq[NamedExpression]): DataFrame = {
- how.toLowerCase(Locale.ROOT) match {
- case "any" => drop0(cols.size, cols)
- case "all" => drop0(1, cols)
- case _ => throw new IllegalArgumentException(s"how ($how) must be 'any'
or 'all'")
- }
- }
- private def drop0(minNonNulls: Int, cols: Seq[NamedExpression]): DataFrame =
{
+ private def drop0(minNonNulls: Option[Int], cols: Seq[NamedExpression]):
DataFrame = {
// Filtering condition:
// only keep the row if it has at least `minNonNulls` non-null and non-NaN
values.
- val predicate = AtLeastNNonNulls(minNonNulls, cols)
+ val predicate = AtLeastNNonNulls(minNonNulls.getOrElse(cols.size), cols)
df.filter(column(predicate))
}
@@ -547,4 +261,58 @@ final class DataFrameNaFunctions private[sql](df:
DataFrame) {
}
df.select(projections : _*)
}
+
+ /** @inheritdoc */
+ override def drop(): DataFrame = super.drop()
+
+ /** @inheritdoc */
+ override def drop(cols: Array[String]): DataFrame = super.drop(cols)
+
+ /** @inheritdoc */
+ override def drop(cols: Seq[String]): DataFrame = super.drop(cols)
+
+ /** @inheritdoc */
+ override def drop(how: String, cols: Array[String]): DataFrame =
super.drop(how, cols)
+
+ /** @inheritdoc */
+ override def drop(minNonNulls: Int, cols: Array[String]): DataFrame =
+ super.drop(minNonNulls, cols)
+
+ /** @inheritdoc */
+ override def drop(how: String): DataFrame = super.drop(how)
+
+ /** @inheritdoc */
+ override def drop(how: String, cols: Seq[String]): DataFrame =
super.drop(how, cols)
+
+ /** @inheritdoc */
+ override def drop(minNonNulls: Int): DataFrame = super.drop(minNonNulls)
+
+ /** @inheritdoc */
+ override def drop(minNonNulls: Int, cols: Seq[String]): DataFrame =
super.drop(minNonNulls, cols)
+
+ /** @inheritdoc */
+ override def fill(value: Long, cols: Array[String]): DataFrame =
super.fill(value, cols)
+
+ /** @inheritdoc */
+ override def fill(value: Double, cols: Array[String]): DataFrame =
super.fill(value, cols)
+
+ /** @inheritdoc */
+ override def fill(value: String, cols: Array[String]): DataFrame =
super.fill(value, cols)
+
+ /** @inheritdoc */
+ override def fill(value: Boolean, cols: Array[String]): DataFrame =
super.fill(value, cols)
+
+ /** @inheritdoc */
+ override def fill(valueMap: java.util.Map[String, Any]): DataFrame =
super.fill(valueMap)
+
+ /** @inheritdoc */
+ override def fill(valueMap: Map[String, Any]): DataFrame =
super.fill(valueMap)
+
+ /** @inheritdoc */
+ override def replace[T](col: String, replacement: java.util.Map[T, T]):
DataFrame =
+ super.replace[T](col, replacement)
+
+ /** @inheritdoc */
+ override def replace[T](cols: Array[String], replacement: java.util.Map[T,
T]): DataFrame =
+ super.replace(cols, replacement)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 5288d77d40b1..38521e8e16f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -577,28 +577,10 @@ class Dataset[T] private[sql](
println(showString(numRows, truncate, vertical))
// scalastyle:on println
- /**
- * Returns a [[DataFrameNaFunctions]] for working with missing data.
- * {{{
- * // Dropping rows containing any null values.
- * ds.na.drop()
- * }}}
- *
- * @group untypedrel
- * @since 1.6.0
- */
+ /** @inheritdoc */
def na: DataFrameNaFunctions = new DataFrameNaFunctions(toDF())
- /**
- * Returns a [[DataFrameStatFunctions]] for working statistic functions
support.
- * {{{
- * // Finding frequent items in column with name 'a'.
- * ds.stat.freqItems(Seq("a"))
- * }}}
- *
- * @group untypedrel
- * @since 1.6.0
- */
+ /** @inheritdoc */
def stat: DataFrameStatFunctions = new DataFrameStatFunctions(toDF())
/** @inheritdoc */
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]