This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new 6112d78cba2  [SPARK-45052][SQL][PYTHON][CONNECT][3.5] Make function aliases output column name consistent with SQL
6112d78cba2 is described below

commit 6112d78cba20fd2e9aa298190371dd52205dc762
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Mon Sep 4 16:24:43 2023 +0800

    [SPARK-45052][SQL][PYTHON][CONNECT][3.5] Make function aliases output column name consistent with SQL

    ### What changes were proposed in this pull request?

    Backport https://github.com/apache/spark/pull/42775 to branch-3.5.

    ### Why are the changes needed?

    To make the output column name of `func(col)` consistent with that of the equivalent SQL expression `expr("func(col)")`.

    ### Does this PR introduce _any_ user-facing change?

    Yes. The affected function aliases now resolve under their own names, so their output column names change accordingly (for example, `ceiling(b)` instead of `CEIL(b)`).

    ### How was this patch tested?

    CI.

    ### Was this patch authored or co-authored using generative AI tooling?

    No.

    Closes #42786 from zhengruifeng/try_column_name_35.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../scala/org/apache/spark/sql/functions.scala     |  12 +-
 .../query-tests/explain-results/describe.explain   |   2 +-
 .../explain-results/function_ceiling.explain       |   2 +-
 .../explain-results/function_ceiling_scale.explain |   2 +-
 .../explain-results/function_printf.explain        |   2 +-
 .../explain-results/function_sign.explain          |   2 +-
 .../explain-results/function_std.explain           |   2 +-
 .../query-tests/queries/function_ceiling.json      |   2 +-
 .../query-tests/queries/function_ceiling.proto.bin | Bin 173 -> 176 bytes
 .../queries/function_ceiling_scale.json            |   2 +-
 .../queries/function_ceiling_scale.proto.bin       | Bin 179 -> 182 bytes
 .../query-tests/queries/function_printf.json       |   2 +-
 .../query-tests/queries/function_printf.proto.bin  | Bin 196 -> 189 bytes
 .../query-tests/queries/function_sign.json         |   2 +-
 .../query-tests/queries/function_sign.proto.bin    | Bin 175 -> 173 bytes
 .../query-tests/queries/function_std.json          |   2 +-
 .../query-tests/queries/function_std.proto.bin     | Bin 175 -> 172 bytes
 python/pyspark/sql/connect/functions.py            |  26 +-
 python/pyspark/sql/functions.py                    | 714 +++++++++++++++------
 .../scala/org/apache/spark/sql/functions.scala     | 202 +++---
 20 files changed, 628 insertions(+), 348 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index fa8c5782e06..fe992ae6740 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -987,7 +987,7 @@ object functions {
    * @group agg_funcs
    * @since 3.5.0
    */
-  def std(e: Column): Column = stddev(e)
+  def std(e: Column): Column = Column.fn("std", e)
 
   /**
    * Aggregate function: alias for `stddev_samp`.
@@ -2337,7 +2337,7 @@ object functions {
    * @group math_funcs
    * @since 3.5.0
    */
-  def ceiling(e: Column, scale: Column): Column = ceil(e, scale)
+  def ceiling(e: Column, scale: Column): Column = Column.fn("ceiling", e, scale)
 
   /**
    * Computes the ceiling of the given value of `e` to 0 decimal places.
@@ -2345,7 +2345,7 @@ object functions {
    * @group math_funcs
    * @since 3.5.0
    */
-  def ceiling(e: Column): Column = ceil(e)
+  def ceiling(e: Column): Column = Column.fn("ceiling", e)
 
   /**
    * Convert a number in a string column from one base to another.
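The user-visible effect of the alias changes above is easiest to see in the resulting column names. Below is a minimal PySpark sketch of the new behavior, not part of the commit itself (assuming Spark 3.5 with this patch and an active SparkSession named `spark`; the expected names come from the golden files updated later in this commit):

    from pyspark.sql import SparkSession
    import pyspark.sql.functions as sf

    spark = SparkSession.builder.getOrCreate()
    df = spark.range(3).withColumn("b", sf.col("id") + 0.5)

    # ceiling(col) used to be lowered to ceil(col) and surfaced as CEIL(b);
    # it now keeps its own name, matching expr("ceiling(b)").
    print(df.select(sf.ceiling("b")).columns)        # ['ceiling(b)']
    print(df.select(sf.expr("ceiling(b)")).columns)  # ['ceiling(b)']

    # Likewise, std(col) is no longer rewritten to stddev(col).
    print(df.agg(sf.std("b")).columns)               # ['std(b)']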
@@ -2800,7 +2800,7 @@ object functions {
    * @group math_funcs
    * @since 3.5.0
    */
-  def power(l: Column, r: Column): Column = pow(l, r)
+  def power(l: Column, r: Column): Column = Column.fn("power", l, r)
 
   /**
    * Returns the positive value of dividend mod divisor.
@@ -2937,7 +2937,7 @@ object functions {
    * @group math_funcs
    * @since 3.5.0
    */
-  def sign(e: Column): Column = signum(e)
+  def sign(e: Column): Column = Column.fn("sign", e)
 
   /**
    * Computes the signum of the given value.
@@ -4420,7 +4420,7 @@ object functions {
    * @since 3.5.0
    */
   def printf(format: Column, arguments: Column*): Column =
-    Column.fn("format_string", lit(format) +: arguments: _*)
+    Column.fn("printf", (format +: arguments): _*)
 
   /**
    * Decodes a `str` in 'application/x-www-form-urlencoded' format using a specific encoding
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/describe.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/describe.explain
index f205f7ef7a1..b203f715c71 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/describe.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/describe.explain
@@ -1,6 +1,6 @@
 Project [summary#0, element_at(id#0, summary#0, None, false) AS id#0, element_at(b#0, summary#0, None, false) AS b#0]
 +- Project [id#0, b#0, summary#0]
    +- Generate explode([count,mean,stddev,min,max]), false, [summary#0]
-      +- Aggregate [map(cast(count as string), cast(count(id#0L) as string), cast(mean as string), cast(avg(id#0L) as string), cast(stddev as string), cast(stddev_samp(cast(id#0L as double)) as string), cast(min as string), cast(min(id#0L) as string), cast(max as string), cast(max(id#0L) as string)) AS id#0, map(cast(count as string), cast(count(b#0) as string), cast(mean as string), cast(avg(b#0) as string), cast(stddev as string), cast(stddev_samp(b#0) as string), cast(min as string), [...]
+      +- Aggregate [map(cast(count as string), cast(count(id#0L) as string), cast(mean as string), cast(avg(id#0L) as string), cast(stddev as string), cast(stddev(cast(id#0L as double)) as string), cast(min as string), cast(min(id#0L) as string), cast(max as string), cast(max(id#0L) as string)) AS id#0, map(cast(count as string), cast(count(b#0) as string), cast(mean as string), cast(avg(b#0) as string), cast(stddev as string), cast(stddev(b#0) as string), cast(min as string), cast(min(b [...]
         +- Project [id#0L, b#0]
            +- LocalRelation <empty>, [id#0L, a#0, b#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling.explain
index 9cf776a8dba..217d7434b80 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling.explain
@@ -1,2 +1,2 @@
-Project [CEIL(b#0) AS CEIL(b)#0L]
+Project [ceiling(b#0) AS ceiling(b)#0L]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling_scale.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling_scale.explain
index cdf8d356e47..2c41c12278b 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling_scale.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_ceiling_scale.explain
@@ -1,2 +1,2 @@
-Project [ceil(cast(b#0 as decimal(30,15)), 2) AS ceil(b, 2)#0]
+Project [ceiling(cast(b#0 as decimal(30,15)), 2) AS ceiling(b, 2)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_printf.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_printf.explain
index 10409df0070..8d55d773400 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_printf.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_printf.explain
@@ -1,2 +1,2 @@
-Project [format_string(g#0, a#0, g#0) AS format_string(g, a, g)#0]
+Project [printf(g#0, a#0, g#0) AS printf(g, a, g)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_sign.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_sign.explain
index 807fa330083..5d41e16b6ce 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_sign.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_sign.explain
@@ -1,2 +1,2 @@
-Project [SIGNUM(b#0) AS SIGNUM(b)#0]
+Project [sign(b#0) AS sign(b)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_std.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_std.explain
index 106191e5a32..cf5b86ae3a5 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_std.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_std.explain
@@ -1,2 +1,2 @@
-Aggregate [stddev(cast(a#0 as double)) AS stddev(a)#0]
+Aggregate [std(cast(a#0 as double)) AS std(a)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.json b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.json
index 5a9961ab47f..99726305e85 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.json
@@ -13,7 +13,7 @@
   },
   "expressions": [{
     "unresolvedFunction": {
-      "functionName": "ceil",
+      "functionName": "ceiling",
       "arguments": [{
         "unresolvedAttribute": {
           "unparsedIdentifier": "b"
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.proto.bin
index 3761deb1663..cc91ac246a5 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.json b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.json
index bda5e85924c..c0b0742b121 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.json
@@ -13,7 +13,7 @@
   },
   "expressions": [{
     "unresolvedFunction": {
-      "functionName": "ceil",
+      "functionName": "ceiling",
       "arguments": [{
         "unresolvedAttribute": {
           "unparsedIdentifier": "b"
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.proto.bin
index 8db402ac167..30efc42b9d2 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_ceiling_scale.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_printf.json b/connector/connect/common/src/test/resources/query-tests/queries/function_printf.json
index dc7ca880c4b..73ca595e865 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/function_printf.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_printf.json
@@ -13,7 +13,7 @@
   },
   "expressions": [{
     "unresolvedFunction": {
-      "functionName": "format_string",
+      "functionName": "printf",
       "arguments": [{
         "unresolvedAttribute": {
           "unparsedIdentifier": "g"
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_printf.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_printf.proto.bin
index 7ebdda6cac1..3fb3862f44d 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_printf.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_printf.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_sign.json b/connector/connect/common/src/test/resources/query-tests/queries/function_sign.json
index bcf6ad7eb17..34451969078 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/function_sign.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_sign.json
@@ -13,7 +13,7 @@
   },
   "expressions": [{
     "unresolvedFunction": {
-      "functionName": "signum",
+      "functionName": "sign",
       "arguments": [{
         "unresolvedAttribute": {
           "unparsedIdentifier": "b"
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_sign.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_sign.proto.bin
index af52abfb7f2..ff866c97303 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_sign.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_sign.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_std.json b/connector/connect/common/src/test/resources/query-tests/queries/function_std.json
index 1403817886c..cbdb4ea9e5e 100644
--- a/connector/connect/common/src/test/resources/query-tests/queries/function_std.json
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_std.json
@@ -13,7 +13,7 @@
   },
   "expressions": [{
     "unresolvedFunction": {
-      "functionName": "stddev",
+      "functionName": "std",
       "arguments": [{
         "unresolvedAttribute": {
           "unparsedIdentifier": "a"
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_std.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_std.proto.bin
index 8d214eea8e7..7e34b0427c2 100644
Binary files a/connector/connect/common/src/test/resources/query-tests/queries/function_std.proto.bin and b/connector/connect/common/src/test/resources/query-tests/queries/function_std.proto.bin differ
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 0c511d6c07c..e2583f84c41 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -558,7 +558,11 @@ def ceil(col: "ColumnOrName") -> Column:
 ceil.__doc__ = pysparkfuncs.ceil.__doc__
 
 
-ceiling = ceil
+def ceiling(col: "ColumnOrName") -> Column:
+    return _invoke_function_over_columns("ceiling", col)
+
+
+ceiling.__doc__ = pysparkfuncs.ceiling.__doc__
 
 
 def conv(col: "ColumnOrName", fromBase: int, toBase: int) -> Column:
@@ -823,7 +827,11 @@ def signum(col: "ColumnOrName") -> Column:
 signum.__doc__ = pysparkfuncs.signum.__doc__
 
 
-sign = signum
+def sign(col: "ColumnOrName") -> Column:
+    return _invoke_function_over_columns("sign", col)
+
+
+sign.__doc__ = pysparkfuncs.sign.__doc__
 
 
 def sin(col: "ColumnOrName") -> Column:
@@ -1199,13 +1207,17 @@ skewness.__doc__ = pysparkfuncs.skewness.__doc__
 
 
 def stddev(col: "ColumnOrName") -> Column:
-    return stddev_samp(col)
+    return _invoke_function_over_columns("stddev", col)
 
 
 stddev.__doc__ = pysparkfuncs.stddev.__doc__
 
 
-std = stddev
+def std(col: "ColumnOrName") -> Column:
+    return _invoke_function_over_columns("std", col)
+
+
+std.__doc__ = pysparkfuncs.std.__doc__
 
 
 def stddev_samp(col: "ColumnOrName") -> Column:
@@ -1329,7 +1341,7 @@ variance.__doc__ = pysparkfuncs.variance.__doc__
 
 
 def every(col: "ColumnOrName") -> Column:
-    return _invoke_function_over_columns("bool_and", col)
+    return _invoke_function_over_columns("every", col)
 
 
 every.__doc__ = pysparkfuncs.every.__doc__
@@ -1343,7 +1355,7 @@ bool_and.__doc__ = pysparkfuncs.bool_and.__doc__
 
 
 def some(col: "ColumnOrName") -> Column:
-    return _invoke_function_over_columns("bool_or", col)
+    return _invoke_function_over_columns("some", col)
 
 
 some.__doc__ = pysparkfuncs.some.__doc__
@@ -2558,7 +2570,7 @@ parse_url.__doc__ = pysparkfuncs.parse_url.__doc__
 
 
 def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column:
-    return _invoke_function("printf", lit(format), *[_to_col(c) for c in cols])
+    return _invoke_function("printf", _to_col(format), *[_to_col(c) for c in cols])
 
 
 printf.__doc__ = pysparkfuncs.printf.__doc__
diff --git
a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 4aafcf90d8f..1fe2f7d40a2 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -412,9 +412,15 @@ def try_avg(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(1982, 15), (1990, 2)], ["birth", "age"]) - >>> df.select(try_avg(df.age).alias('r')).collect() - [Row(r=8.5)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [(1982, 15), (1990, 2)], ["birth", "age"] + ... ).select(sf.try_avg("age")).show() + +------------+ + |try_avg(age)| + +------------+ + | 8.5| + +------------+ """ return _invoke_function_over_columns("try_avg", col) @@ -565,9 +571,13 @@ def try_sum(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(10) - >>> df.select(try_sum(df["id"]).alias('r')).collect() - [Row(r=45)] + >>> import pyspark.sql.functions as sf + >>> spark.range(10).select(sf.try_sum("id")).show() + +-----------+ + |try_sum(id)| + +-----------+ + | 45| + +-----------+ """ return _invoke_function_over_columns("try_sum", col) @@ -1316,7 +1326,37 @@ def ceil(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("ceil", col) -ceiling = ceil +@try_remote_functions +def ceiling(col: "ColumnOrName") -> Column: + """ + Computes the ceiling of the given value. + + .. versionadded:: 1.4.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for computed results. + + Examples + -------- + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.ceil(sf.lit(-0.1))).show() + +----------+ + |CEIL(-0.1)| + +----------+ + | 0| + +----------+ + """ + return _invoke_function_over_columns("ceiling", col) @try_remote_functions @@ -1668,14 +1708,15 @@ def negative(col: "ColumnOrName") -> Column: Examples -------- - >>> spark.range(3).select(negative("id").alias("n")).show() - +---+ - | n| - +---+ - | 0| - | -1| - | -2| - +---+ + >>> import pyspark.sql.functions as sf + >>> spark.range(3).select(sf.negative("id")).show() + +------------+ + |negative(id)| + +------------+ + | 0| + | -1| + | -2| + +------------+ """ return _invoke_function_over_columns("negative", col) @@ -1825,25 +1866,54 @@ def signum(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(signum(lit(-5))).show() - +----------+ - |SIGNUM(-5)| - +----------+ - | -1.0| - +----------+ - - >>> df.select(signum(lit(6))).show() - +---------+ - |SIGNUM(6)| - +---------+ - | 1.0| - +---------+ + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select( + ... sf.signum(sf.lit(-5)), + ... sf.signum(sf.lit(6)) + ... ).show() + +----------+---------+ + |SIGNUM(-5)|SIGNUM(6)| + +----------+---------+ + | -1.0| 1.0| + +----------+---------+ """ return _invoke_function_over_columns("signum", col) -sign = signum +@try_remote_functions +def sign(col: "ColumnOrName") -> Column: + """ + Computes the signum of the given value. + + .. versionadded:: 1.4.0 + + .. versionchanged:: 3.4.0 + Supports Spark Connect. + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + the column for computed results. + + Examples + -------- + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select( + ... sf.sign(sf.lit(-5)), + ... sf.sign(sf.lit(6)) + ... 
).show() + +--------+-------+ + |sign(-5)|sign(6)| + +--------+-------+ + | -1.0| 1.0| + +--------+-------+ + """ + return _invoke_function_over_columns("sign", col) @try_remote_functions @@ -2146,15 +2216,17 @@ def getbit(col: "ColumnOrName", pos: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([[1],[1],[2]], ["c"]) - >>> df.select(getbit("c", lit(1)).alias("d")).show() - +---+ - | d| - +---+ - | 0| - | 0| - | 1| - +---+ + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[1], [1], [2]], ["c"] + ... ).select(sf.getbit("c", sf.lit(1))).show() + +------------+ + |getbit(c, 1)| + +------------+ + | 0| + | 0| + | 1| + +------------+ """ return _invoke_function_over_columns("getbit", col, pos) @@ -2351,14 +2423,45 @@ def stddev(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(6) - >>> df.select(stddev(df.id)).first() - Row(stddev_samp(id)=1.87082...) + >>> import pyspark.sql.functions as sf + >>> spark.range(6).select(sf.stddev("id")).show() + +------------------+ + | stddev(id)| + +------------------+ + |1.8708286933869...| + +------------------+ """ return _invoke_function_over_columns("stddev", col) -std = stddev +@try_remote_functions +def std(col: "ColumnOrName") -> Column: + """ + Aggregate function: alias for stddev_samp. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or str + target column to compute on. + + Returns + ------- + :class:`~pyspark.sql.Column` + standard deviation of given column. + + Examples + -------- + >>> import pyspark.sql.functions as sf + >>> spark.range(6).select(sf.std("id")).show() + +------------------+ + | std(id)| + +------------------+ + |1.8708286933869...| + +------------------+ + """ + return _invoke_function_over_columns("std", col) @try_remote_functions @@ -2384,9 +2487,13 @@ def stddev_samp(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(6) - >>> df.select(stddev_samp(df.id)).first() - Row(stddev_samp(id)=1.87082...) + >>> import pyspark.sql.functions as sf + >>> spark.range(6).select(sf.stddev_samp("id")).show() + +------------------+ + | stddev_samp(id)| + +------------------+ + |1.8708286933869...| + +------------------+ """ return _invoke_function_over_columns("stddev_samp", col) @@ -2414,9 +2521,13 @@ def stddev_pop(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.range(6) - >>> df.select(stddev_pop(df.id)).first() - Row(stddev_pop(id)=1.70782...) + >>> import pyspark.sql.functions as sf + >>> spark.range(6).select(sf.stddev_pop("id")).show() + +-----------------+ + | stddev_pop(id)| + +-----------------+ + |1.707825127659...| + +-----------------+ """ return _invoke_function_over_columns("stddev_pop", col) @@ -2816,27 +2927,35 @@ def every(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([[True], [True], [True]], ["flag"]) - >>> df.select(every("flag")).show() - +--------------+ - |bool_and(flag)| - +--------------+ - | true| - +--------------+ - >>> df = spark.createDataFrame([[True], [False], [True]], ["flag"]) - >>> df.select(every("flag")).show() - +--------------+ - |bool_and(flag)| - +--------------+ - | false| - +--------------+ - >>> df = spark.createDataFrame([[False], [False], [False]], ["flag"]) - >>> df.select(every("flag")).show() - +--------------+ - |bool_and(flag)| - +--------------+ - | false| - +--------------+ + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[True], [True], [True]], ["flag"] + ... 
).select(sf.every("flag")).show() + +-----------+ + |every(flag)| + +-----------+ + | true| + +-----------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[True], [False], [True]], ["flag"] + ... ).select(sf.every("flag")).show() + +-----------+ + |every(flag)| + +-----------+ + | false| + +-----------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[False], [False], [False]], ["flag"] + ... ).select(sf.every("flag")).show() + +-----------+ + |every(flag)| + +-----------+ + | false| + +-----------+ """ return _invoke_function_over_columns("every", col) @@ -2904,27 +3023,35 @@ def some(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([[True], [True], [True]], ["flag"]) - >>> df.select(some("flag")).show() - +-------------+ - |bool_or(flag)| - +-------------+ - | true| - +-------------+ - >>> df = spark.createDataFrame([[True], [False], [True]], ["flag"]) - >>> df.select(some("flag")).show() - +-------------+ - |bool_or(flag)| - +-------------+ - | true| - +-------------+ - >>> df = spark.createDataFrame([[False], [False], [False]], ["flag"]) - >>> df.select(some("flag")).show() - +-------------+ - |bool_or(flag)| - +-------------+ - | false| - +-------------+ + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[True], [True], [True]], ["flag"] + ... ).select(sf.some("flag")).show() + +----------+ + |some(flag)| + +----------+ + | true| + +----------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[True], [False], [True]], ["flag"] + ... ).select(sf.some("flag")).show() + +----------+ + |some(flag)| + +----------+ + | true| + +----------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [[False], [False], [False]], ["flag"] + ... ).select(sf.some("flag")).show() + +----------+ + |some(flag)| + +----------+ + | false| + +----------+ """ return _invoke_function_over_columns("some", col) @@ -4546,22 +4673,23 @@ def approx_percentile( Examples -------- - >>> key = (col("id") % 3).alias("key") - >>> value = (randn(42) + key * 10).alias("value") + >>> import pyspark.sql.functions as sf + >>> key = (sf.col("id") % 3).alias("key") + >>> value = (sf.randn(42) + key * 10).alias("value") >>> df = spark.range(0, 1000, 1, 1).select(key, value) >>> df.select( - ... approx_percentile("value", [0.25, 0.5, 0.75], 1000000).alias("quantiles") + ... sf.approx_percentile("value", [0.25, 0.5, 0.75], 1000000) ... ).printSchema() root - |-- quantiles: array (nullable = true) + |-- approx_percentile(value, array(0.25, 0.5, 0.75), 1000000): array (nullable = true) | |-- element: double (containsNull = false) >>> df.groupBy("key").agg( - ... approx_percentile("value", 0.5, lit(1000000)).alias("median") + ... sf.approx_percentile("value", 0.5, sf.lit(1000000)) ... ).printSchema() root |-- key: long (nullable = true) - |-- median: double (nullable = true) + |-- approx_percentile(value, 0.5, 1000000): double (nullable = true) """ sc = get_active_spark_context() @@ -5603,15 +5731,25 @@ def first_value(col: "ColumnOrName", ignoreNulls: Optional[Union[bool, Column]] Examples -------- - >>> df = spark.createDataFrame([(None, 1), - ... ("a", 2), - ... ("a", 3), - ... ("b", 8), - ... 
("b", 2)], ["c1", "c2"]) - >>> df.select(first_value('c1').alias('a'), first_value('c2').alias('b')).collect() - [Row(a=None, b=1)] - >>> df.select(first_value('c1', True).alias('a'), first_value('c2', True).alias('b')).collect() - [Row(a='a', b=1)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [(None, 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["a", "b"] + ... ).select(sf.first_value('a'), sf.first_value('b')).show() + +--------------+--------------+ + |first_value(a)|first_value(b)| + +--------------+--------------+ + | NULL| 1| + +--------------+--------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [(None, 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["a", "b"] + ... ).select(sf.first_value('a', True), sf.first_value('b', True)).show() + +--------------+--------------+ + |first_value(a)|first_value(b)| + +--------------+--------------+ + | a| 1| + +--------------+--------------+ """ if ignoreNulls is None: return _invoke_function_over_columns("first_value", col) @@ -5641,15 +5779,25 @@ def last_value(col: "ColumnOrName", ignoreNulls: Optional[Union[bool, Column]] = Examples -------- - >>> df = spark.createDataFrame([("a", 1), - ... ("a", 2), - ... ("a", 3), - ... ("b", 8), - ... (None, 2)], ["c1", "c2"]) - >>> df.select(last_value('c1').alias('a'), last_value('c2').alias('b')).collect() - [Row(a=None, b=2)] - >>> df.select(last_value('c1', True).alias('a'), last_value('c2', True).alias('b')).collect() - [Row(a='b', b=2)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("a", 1), ("a", 2), ("a", 3), ("b", 8), (None, 2)], ["a", "b"] + ... ).select(sf.last_value('a'), sf.last_value('b')).show() + +-------------+-------------+ + |last_value(a)|last_value(b)| + +-------------+-------------+ + | NULL| 2| + +-------------+-------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("a", 1), ("a", 2), ("a", 3), ("b", 8), (None, 2)], ["a", "b"] + ... ).select(sf.last_value('a', True), sf.last_value('b', True)).show() + +-------------+-------------+ + |last_value(a)|last_value(b)| + +-------------+-------------+ + | b| 2| + +-------------+-------------+ """ if ignoreNulls is None: return _invoke_function_over_columns("last_value", col) @@ -5811,8 +5959,8 @@ def curdate() -> Column: Examples -------- - >>> df = spark.range(1) - >>> df.select(curdate()).show() # doctest: +SKIP + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.curdate()).show() # doctest: +SKIP +--------------+ |current_date()| +--------------+ @@ -6558,13 +6706,35 @@ def dateadd(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: Examples -------- - >>> df = spark.createDataFrame([('2015-04-08', 2,)], ['dt', 'add']) - >>> df.select(dateadd(df.dt, 1).alias('next_date')).collect() - [Row(next_date=datetime.date(2015, 4, 9))] - >>> df.select(dateadd(df.dt, df.add.cast('integer')).alias('next_date')).collect() - [Row(next_date=datetime.date(2015, 4, 10))] - >>> df.select(dateadd('dt', -1).alias('prev_date')).collect() - [Row(prev_date=datetime.date(2015, 4, 7))] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [('2015-04-08', 2,)], ['dt', 'add'] + ... ).select(sf.dateadd("dt", 1)).show() + +---------------+ + |date_add(dt, 1)| + +---------------+ + | 2015-04-09| + +---------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [('2015-04-08', 2,)], ['dt', 'add'] + ... 
).select(sf.dateadd("dt", sf.lit(2))).show() + +---------------+ + |date_add(dt, 2)| + +---------------+ + | 2015-04-10| + +---------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [('2015-04-08', 2,)], ['dt', 'add'] + ... ).select(sf.dateadd("dt", -1)).show() + +----------------+ + |date_add(dt, -1)| + +----------------+ + | 2015-04-07| + +----------------+ """ days = lit(days) if isinstance(days, int) else days return _invoke_function_over_columns("dateadd", start, days) @@ -7031,9 +7201,15 @@ def xpath_number(xml: "ColumnOrName", path: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([('<a><b>1</b><b>2</b></a>',)], ['x']) - >>> df.select(xpath_number(df.x, lit('sum(a/b)')).alias('r')).collect() - [Row(r=3.0)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [('<a><b>1</b><b>2</b></a>',)], ['x'] + ... ).select(sf.xpath_number('x', sf.lit('sum(a/b)'))).show() + +-------------------------+ + |xpath_number(x, sum(a/b))| + +-------------------------+ + | 3.0| + +-------------------------+ """ return _invoke_function_over_columns("xpath_number", xml, path) @@ -7926,7 +8102,8 @@ def current_schema() -> Column: Examples -------- - >>> spark.range(1).select(current_schema()).show() + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.current_schema()).show() +------------------+ |current_database()| +------------------+ @@ -7962,7 +8139,8 @@ def user() -> Column: Examples -------- - >>> spark.range(1).select(user()).show() # doctest: +SKIP + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.user()).show() # doctest: +SKIP +--------------+ |current_user()| +--------------+ @@ -9228,13 +9406,35 @@ def regexp(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - >>> df.select(regexp('str', lit(r'(\d+)')).alias('d')).collect() - [Row(d=True)] - >>> df.select(regexp('str', lit(r'\d{2}b')).alias('d')).collect() - [Row(d=False)] - >>> df.select(regexp("str", col("regexp")).alias('d')).collect() - [Row(d=True)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] + ... ).select(sf.regexp('str', sf.lit(r'(\d+)'))).show() + +------------------+ + |REGEXP(str, (\d+))| + +------------------+ + | true| + +------------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] + ... ).select(sf.regexp('str', sf.lit(r'\d{2}b'))).show() + +-------------------+ + |REGEXP(str, \d{2}b)| + +-------------------+ + | false| + +-------------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] + ... 
).select(sf.regexp('str', sf.col("regexp"))).show() + +-------------------+ + |REGEXP(str, regexp)| + +-------------------+ + | true| + +-------------------+ """ return _invoke_function_over_columns("regexp", str, regexp) @@ -9259,13 +9459,35 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("1a 2b 14m", r"(\d+)")], ["str", "regexp"]) - >>> df.select(regexp_like('str', lit(r'(\d+)')).alias('d')).collect() - [Row(d=True)] - >>> df.select(regexp_like('str', lit(r'\d{2}b')).alias('d')).collect() - [Row(d=False)] - >>> df.select(regexp_like("str", col("regexp")).alias('d')).collect() - [Row(d=True)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] + ... ).select(sf.regexp_like('str', sf.lit(r'(\d+)'))).show() + +-----------------------+ + |REGEXP_LIKE(str, (\d+))| + +-----------------------+ + | true| + +-----------------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] + ... ).select(sf.regexp_like('str', sf.lit(r'\d{2}b'))).show() + +------------------------+ + |REGEXP_LIKE(str, \d{2}b)| + +------------------------+ + | false| + +------------------------+ + + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("1a 2b 14m", r"(\d+)")], ["str", "regexp"] + ... ).select(sf.regexp_like('str', sf.col("regexp"))).show() + +------------------------+ + |REGEXP_LIKE(str, regexp)| + +------------------------+ + | true| + +------------------------+ """ return _invoke_function_over_columns("regexp_like", str, regexp) @@ -10006,12 +10228,25 @@ def substr( Examples -------- - >>> df = spark.createDataFrame([("Spark SQL", 5, 1,)], ["a", "b", "c"]) - >>> df.select(substr(df.a, df.b, df.c).alias('r')).collect() - [Row(r='k')] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("Spark SQL", 5, 1,)], ["a", "b", "c"] + ... ).select(sf.substr("a", "b", "c")).show() + +---------------+ + |substr(a, b, c)| + +---------------+ + | k| + +---------------+ - >>> df.select(substr(df.a, df.b).alias('r')).collect() - [Row(r='k SQL')] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("Spark SQL", 5, 1,)], ["a", "b", "c"] + ... ).select(sf.substr("a", "b")).show() + +------------------------+ + |substr(a, b, 2147483647)| + +------------------------+ + | k SQL| + +------------------------+ """ if len is not None: return _invoke_function_over_columns("substr", str, pos, len) @@ -10071,9 +10306,15 @@ def printf(format: "ColumnOrName", *cols: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("aa%d%s", 123, "cc",)], ["a", "b", "c"]) - >>> df.select(printf(df.a, df.b, df.c).alias('r')).collect() - [Row(r='aa123cc')] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [("aa%d%s", 123, "cc",)], ["a", "b", "c"] + ... ).select(sf.printf("a", "b", "c")).show() + +---------------+ + |printf(a, b, c)| + +---------------+ + | aa123cc| + +---------------+ """ sc = get_active_spark_context() return _invoke_function("printf", _to_java_column(format), _to_seq(sc, cols, _to_java_column)) @@ -10144,12 +10385,24 @@ def position( Examples -------- - >>> df = spark.createDataFrame([("bar", "foobarbar", 5,)], ["a", "b", "c"]) - >>> df.select(position(df.a, df.b, df.c).alias('r')).collect() - [Row(r=7)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... 
[("bar", "foobarbar", 5,)], ["a", "b", "c"] + ... ).select(sf.position("a", "b", "c")).show() + +-----------------+ + |position(a, b, c)| + +-----------------+ + | 7| + +-----------------+ - >>> df.select(position(df.a, df.b).alias('r')).collect() - [Row(r=4)] + >>> spark.createDataFrame( + ... [("bar", "foobarbar", 5,)], ["a", "b", "c"] + ... ).select(sf.position("a", "b")).show() + +-----------------+ + |position(a, b, 1)| + +-----------------+ + | 4| + +-----------------+ """ if start is not None: return _invoke_function_over_columns("position", substr, str, start) @@ -10248,9 +10501,13 @@ def char(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([(65,)], ['a']) - >>> df.select(char(df.a).alias('r')).collect() - [Row(r='A')] + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.char(sf.lit(65))).show() + +--------+ + |char(65)| + +--------+ + | A| + +--------+ """ return _invoke_function_over_columns("char", col) @@ -10301,9 +10558,13 @@ def char_length(str: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("SparkSQL",)], ['a']) - >>> df.select(char_length(df.a).alias('r')).collect() - [Row(r=8)] + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.char_length(sf.lit("SparkSQL"))).show() + +---------------------+ + |char_length(SparkSQL)| + +---------------------+ + | 8| + +---------------------+ """ return _invoke_function_over_columns("char_length", str) @@ -10324,9 +10585,13 @@ def character_length(str: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("SparkSQL",)], ['a']) - >>> df.select(character_length(df.a).alias('r')).collect() - [Row(r=8)] + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.character_length(sf.lit("SparkSQL"))).show() + +--------------------------+ + |character_length(SparkSQL)| + +--------------------------+ + | 8| + +--------------------------+ """ return _invoke_function_over_columns("character_length", str) @@ -10589,9 +10854,13 @@ def lcase(str: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark",)], ['a']) - >>> df.select(lcase(df.a).alias('r')).collect() - [Row(r='spark')] + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.lcase(sf.lit("Spark"))).show() + +------------+ + |lcase(Spark)| + +------------+ + | spark| + +------------+ """ return _invoke_function_over_columns("lcase", str) @@ -10610,9 +10879,13 @@ def ucase(str: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark",)], ['a']) - >>> df.select(ucase(df.a).alias('r')).collect() - [Row(r='SPARK')] + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.ucase(sf.lit("Spark"))).show() + +------------+ + |ucase(Spark)| + +------------+ + | SPARK| + +------------+ """ return _invoke_function_over_columns("ucase", str) @@ -12318,9 +12591,17 @@ def cardinality(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([([1, 2, 3],),([1],),([],)], ['data']) - >>> df.select(cardinality(df.data).alias('r')).collect() - [Row(r=3), Row(r=1), Row(r=0)] + >>> import pyspark.sql.functions as sf + >>> spark.createDataFrame( + ... [([1, 2, 3],),([1],),([],)], ['data'] + ... 
).select(sf.cardinality("data")).show() + +-----------------+ + |cardinality(data)| + +-----------------+ + | 3| + | 1| + | 0| + +-----------------+ """ return _invoke_function_over_columns("cardinality", col) @@ -14177,26 +14458,27 @@ def make_timestamp_ltz( Examples -------- + >>> import pyspark.sql.functions as sf >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([[2014, 12, 28, 6, 30, 45.887, 'CET']], ... ["year", "month", "day", "hour", "min", "sec", "timezone"]) - >>> df.select(make_timestamp_ltz( - ... df.year, df.month, df.day, df.hour, df.min, df.sec, df.timezone).alias('r') + >>> df.select(sf.make_timestamp_ltz( + ... df.year, df.month, df.day, df.hour, df.min, df.sec, df.timezone) ... ).show(truncate=False) - +-----------------------+ - |r | - +-----------------------+ - |2014-12-27 21:30:45.887| - +-----------------------+ - - >>> df.select(make_timestamp_ltz( - ... df.year, df.month, df.day, df.hour, df.min, df.sec).alias('r') + +--------------------------------------------------------------+ + |make_timestamp_ltz(year, month, day, hour, min, sec, timezone)| + +--------------------------------------------------------------+ + |2014-12-27 21:30:45.887 | + +--------------------------------------------------------------+ + + >>> df.select(sf.make_timestamp_ltz( + ... df.year, df.month, df.day, df.hour, df.min, df.sec) ... ).show(truncate=False) - +-----------------------+ - |r | - +-----------------------+ - |2014-12-28 06:30:45.887| - +-----------------------+ + +----------------------------------------------------+ + |make_timestamp_ltz(year, month, day, hour, min, sec)| + +----------------------------------------------------+ + |2014-12-28 06:30:45.887 | + +----------------------------------------------------+ >>> spark.conf.unset("spark.sql.session.timeZone") """ if timezone is not None: @@ -14245,17 +14527,18 @@ def make_timestamp_ntz( Examples -------- + >>> import pyspark.sql.functions as sf >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles") >>> df = spark.createDataFrame([[2014, 12, 28, 6, 30, 45.887]], ... ["year", "month", "day", "hour", "min", "sec"]) - >>> df.select(make_timestamp_ntz( - ... df.year, df.month, df.day, df.hour, df.min, df.sec).alias('r') + >>> df.select(sf.make_timestamp_ntz( + ... df.year, df.month, df.day, df.hour, df.min, df.sec) ... 
).show(truncate=False) - +-----------------------+ - |r | - +-----------------------+ - |2014-12-28 06:30:45.887| - +-----------------------+ + +----------------------------------------------------+ + |make_timestamp_ntz(year, month, day, hour, min, sec)| + +----------------------------------------------------+ + |2014-12-28 06:30:45.887 | + +----------------------------------------------------+ >>> spark.conf.unset("spark.sql.session.timeZone") """ return _invoke_function_over_columns( @@ -14689,9 +14972,15 @@ def ifnull(col1: "ColumnOrName", col2: "ColumnOrName") -> Column: Examples -------- + >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([(None,), (1,)], ["e"]) - >>> df.select(ifnull(df.e, lit(8)).alias('r')).collect() - [Row(r=8), Row(r=1)] + >>> df.select(sf.ifnull(df.e, sf.lit(8))).show() + +------------+ + |ifnull(e, 8)| + +------------+ + | 8| + | 1| + +------------+ """ return _invoke_function_over_columns("ifnull", col1, col2) @@ -15057,9 +15346,13 @@ def sha(col: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("Spark",)], ["a"]) - >>> df.select(sha(df.a).alias('r')).collect() - [Row(r='85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c')] + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select(sf.sha(sf.lit("Spark"))).show() + +--------------------+ + | sha(Spark)| + +--------------------+ + |85f5955f4b27a9a4c...| + +--------------------+ """ return _invoke_function_over_columns("sha", col) @@ -15137,12 +15430,19 @@ def java_method(*cols: "ColumnOrName") -> Column: Examples -------- - >>> df = spark.createDataFrame([("a5cf6c42-0c85-418f-af6c-3e4e5b1328f2",)], ["a"]) - >>> df.select( - ... java_method(lit("java.util.UUID"), lit("fromString"), df.a).alias('r') - ... ).collect() - [Row(r='a5cf6c42-0c85-418f-af6c-3e4e5b1328f2')] - + >>> import pyspark.sql.functions as sf + >>> spark.range(1).select( + ... sf.java_method( + ... sf.lit("java.util.UUID"), + ... sf.lit("fromString"), + ... sf.lit("a5cf6c42-0c85-418f-af6c-3e4e5b1328f2") + ... ) + ... ).show(truncate=False) + +-----------------------------------------------------------------------------+ + |java_method(java.util.UUID, fromString, a5cf6c42-0c85-418f-af6c-3e4e5b1328f2)| + +-----------------------------------------------------------------------------+ + |a5cf6c42-0c85-418f-af6c-3e4e5b1328f2 | + +-----------------------------------------------------------------------------+ """ return _invoke_function_over_seq_of_columns("java_method", cols) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 0aaebd44a9e..5653591d45e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -575,7 +575,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def first_value(e: Column): Column = first(e) + def first_value(e: Column): Column = call_function("first_value", e) /** * Aggregate function: returns the first value in a group. 
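To illustrate the `first_value` change above in user-facing terms, here is a small PySpark sketch, not part of the commit (reusing the `spark` session from the earlier sketch; the expected output mirrors the doctests updated in python/pyspark/sql/functions.py):

    import pyspark.sql.functions as sf

    df = spark.createDataFrame(
        [(None, 1), ("a", 2), ("a", 3), ("b", 8), ("b", 2)], ["a", "b"])

    # first_value is no longer rewritten to first(...), so the output columns
    # keep the SQL name first_value(...).
    df.select(sf.first_value("a"), sf.first_value("b")).show()
    # +--------------+--------------+
    # |first_value(a)|first_value(b)|
    # +--------------+--------------+
    # |          NULL|             1|
    # +--------------+--------------+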
@@ -589,9 +589,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def first_value(e: Column, ignoreNulls: Column): Column = withAggregateFunction { - new First(e.expr, ignoreNulls.expr) - } + def first_value(e: Column, ignoreNulls: Column): Column = + call_function("first_value", e, ignoreNulls) /** * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated @@ -848,7 +847,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def last_value(e: Column): Column = last(e) + def last_value(e: Column): Column = call_function("last_value", e) /** * Aggregate function: returns the last value in a group. @@ -862,9 +861,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def last_value(e: Column, ignoreNulls: Column): Column = withAggregateFunction { - new Last(e.expr, ignoreNulls.expr) - } + def last_value(e: Column, ignoreNulls: Column): Column = + call_function("last_value", e, ignoreNulls) /** * Aggregate function: returns the most frequent value in a group. @@ -1017,9 +1015,8 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def approx_percentile(e: Column, percentage: Column, accuracy: Column): Column = { - percentile_approx(e, percentage, accuracy) - } + def approx_percentile(e: Column, percentage: Column, accuracy: Column): Column = + call_function("approx_percentile", e, percentage, accuracy) /** * Aggregate function: returns the product of all numerical elements in a group. @@ -1052,7 +1049,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def std(e: Column): Column = stddev(e) + def std(e: Column): Column = call_function("std", e) /** * Aggregate function: alias for `stddev_samp`. @@ -1060,7 +1057,7 @@ object functions { * @group agg_funcs * @since 1.6.0 */ - def stddev(e: Column): Column = withAggregateFunction { StddevSamp(e.expr) } + def stddev(e: Column): Column = call_function("stddev", e) /** * Aggregate function: alias for `stddev_samp`. @@ -1330,7 +1327,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def every(e: Column): Column = withAggregateFunction { BoolAnd(e.expr) } + def every(e: Column): Column = call_function("every", e) /** * Aggregate function: returns true if all values of `e` are true. @@ -1346,7 +1343,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def some(e: Column): Column = withAggregateFunction { BoolOr(e.expr) } + def some(e: Column): Column = call_function("some", e) /** * Aggregate function: returns true if at least one value of `e` is true. @@ -1354,7 +1351,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def any(e: Column): Column = withAggregateFunction { BoolOr(e.expr) } + def any(e: Column): Column = call_function("any", e) /** * Aggregate function: returns true if at least one value of `e` is true. @@ -1944,9 +1941,8 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_avg(e: Column): Column = withAggregateFunction { - Average(e.expr, EvalMode.TRY) - } + def try_avg(e: Column): Column = + call_function("try_avg", e) /** * Returns `dividend``/``divisor`. It always performs floating point division. Its result is @@ -1984,9 +1980,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def try_sum(e: Column): Column = withAggregateFunction { - Sum(e.expr, EvalMode.TRY) - } + def try_sum(e: Column): Column = call_function("try_sum", e) /** * Creates a new struct column. 
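As a quick illustration of the `try_avg`/`try_sum` hunks above (same assumed `spark` session; the expected output is taken from the doctests updated in this commit):

    import pyspark.sql.functions as sf

    # try_avg and try_sum now resolve under their own names, so the headers
    # read try_avg(age) and try_sum(id), consistent with SQL.
    spark.createDataFrame([(1982, 15), (1990, 2)], ["birth", "age"]) \
        .select(sf.try_avg("age")).show()
    # +------------+
    # |try_avg(age)|
    # +------------+
    # |         8.5|
    # +------------+

    spark.range(10).select(sf.try_sum("id")).show()
    # +-----------+
    # |try_sum(id)|
    # +-----------+
    # |         45|
    # +-----------+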
@@ -2081,7 +2075,7 @@ object functions { * @group bitwise_funcs * @since 3.5.0 */ - def getbit(e: Column, pos: Column): Column = bit_get(e, pos) + def getbit(e: Column, pos: Column): Column = call_function("getbit", e, pos) /** * Parses the expression string into the column that it represents, similar to @@ -2385,7 +2379,8 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ceiling(e: Column, scale: Column): Column = ceil(e, scale) + def ceiling(e: Column, scale: Column): Column = + call_function("ceiling", e, scale) /** * Computes the ceiling of the given value of `e` to 0 decimal places. @@ -2393,7 +2388,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def ceiling(e: Column): Column = ceil(e) + def ceiling(e: Column): Column = call_function("ceiling", e) /** * Convert a number in a string column from one base to another. @@ -2751,7 +2746,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def negative(e: Column): Column = withExpr { UnaryMinus(e.expr) } + def negative(e: Column): Column = call_function("negative", e) /** * Returns Pi. @@ -2979,7 +2974,7 @@ object functions { * @group math_funcs * @since 3.5.0 */ - def sign(e: Column): Column = signum(e) + def sign(e: Column): Column = call_function("sign", e) /** * Computes the signum of the given value. @@ -3184,7 +3179,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def current_schema(): Column = withExpr { CurrentDatabase() } + def current_schema(): Column = call_function("current_schema") /** * Returns the user name of current execution context. @@ -3368,7 +3363,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def user(): Column = withExpr { CurrentUser() } + def user(): Column = call_function("user") /** * Returns an universally unique identifier (UUID) string. The value is returned as a canonical @@ -3638,9 +3633,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def sha(col: Column): Column = withExpr { - Sha1(col.expr) - } + def sha(col: Column): Column = call_function("sha", col) /** * Returns the length of the block being read, or -1 if not available. @@ -3678,9 +3671,8 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def java_method(cols: Column*): Column = withExpr { - CallMethodViaReflection(cols.map(_.expr)) - } + def java_method(cols: Column*): Column = + call_function("java_method", cols: _*) /** * Returns the Spark version. The string contains 2 fields, the first being a release version @@ -3721,9 +3713,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def random(seed: Column): Column = withExpr { - Rand(seed.expr) - } + def random(seed: Column): Column = call_function("random", seed) /** * Returns a random value with independent and identically distributed (i.i.d.) uniformly @@ -3732,9 +3722,7 @@ object functions { * @group misc_funcs * @since 3.5.0 */ - def random(): Column = withExpr { - new Rand() - } + def random(): Column = call_function("random") /** * Returns the bucket number for the given input column. @@ -4040,7 +4028,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp(str: Column, regexp: Column): Column = rlike(str, regexp) + def regexp(str: Column, regexp: Column): Column = + call_function("regexp", str, regexp) /** * Returns true if `str` matches `regexp`, or false otherwise. 
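And the `regexp` change above, seen from PySpark (same assumed `spark` session; expected output from this commit's updated doctests):

    import pyspark.sql.functions as sf

    # regexp(...) is no longer lowered to rlike(...), so the column is named
    # REGEXP(str, (\d+)) exactly as the SQL function would name it.
    spark.createDataFrame([("1a 2b 14m",)], ["str"]) \
        .select(sf.regexp("str", sf.lit(r"(\d+)"))).show()
    # +------------------+
    # |REGEXP(str, (\d+))|
    # +------------------+
    # |              true|
    # +------------------+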
@@ -4048,7 +4037,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def regexp_like(str: Column, regexp: Column): Column = rlike(str, regexp) + def regexp_like(str: Column, regexp: Column): Column = + call_function("regexp_like", str, regexp) /** * Returns a count of the number of times that the regular expression pattern `regexp` @@ -4518,9 +4508,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def substr(str: Column, pos: Column, len: Column): Column = withExpr { - Substring(str.expr, pos.expr, len.expr) - } + def substr(str: Column, pos: Column, len: Column): Column = + call_function("substr", str, pos, len) /** * Returns the substring of `str` that starts at `pos`, @@ -4529,9 +4518,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def substr(str: Column, pos: Column): Column = withExpr { - new Substring(str.expr, pos.expr) - } + def substr(str: Column, pos: Column): Column = + call_function("substr", str, pos) /** * Extracts a part from a URL. @@ -4559,9 +4547,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def printf(format: Column, arguments: Column*): Column = withExpr { - FormatString((lit(format) +: arguments).map(_.expr): _*) - } + def printf(format: Column, arguments: Column*): Column = + call_function("printf", (format +: arguments): _*) /** * Decodes a `str` in 'application/x-www-form-urlencoded' format @@ -4592,9 +4579,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def position(substr: Column, str: Column, start: Column): Column = withExpr { - StringLocate(substr.expr, str.expr, start.expr) - } + def position(substr: Column, str: Column, start: Column): Column = + call_function("position", substr, str, start) /** * Returns the position of the first occurrence of `substr` in `str` after position `1`. @@ -4603,9 +4589,8 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def position(substr: Column, str: Column): Column = withExpr { - new StringLocate(substr.expr, str.expr) - } + def position(substr: Column, str: Column): Column = + call_function("position", substr, str) /** * Returns a boolean. The value is True if str ends with suffix. @@ -4634,9 +4619,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def char(n: Column): Column = withExpr { - Chr(n.expr) - } + def char(n: Column): Column = call_function("char", n) /** * Removes the leading and trailing space characters from `str`. @@ -4700,9 +4683,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def char_length(str: Column): Column = withExpr { - Length(str.expr) - } + def char_length(str: Column): Column = call_function("char_length", str) /** * Returns the character length of string data or number of bytes of binary data. @@ -4712,9 +4693,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def character_length(str: Column): Column = withExpr { - Length(str.expr) - } + def character_length(str: Column): Column = call_function("character_length", str) /** * Returns the ASCII character having the binary equivalent to `n`. @@ -4823,9 +4802,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def lcase(str: Column): Column = withExpr { - Lower(str.expr) - } + def lcase(str: Column): Column = call_function("lcase", str) /** * Returns `str` with all characters changed to uppercase. 
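For the string-function hunks above, `printf` is the clearest example of the renaming (same assumed `spark` session; output matches the doctest added in this commit):

    import pyspark.sql.functions as sf

    # printf(...) previously lowered to format_string(...); it now keeps the
    # printf name, so the header matches SELECT printf(a, b, c).
    spark.createDataFrame([("aa%d%s", 123, "cc")], ["a", "b", "c"]) \
        .select(sf.printf("a", "b", "c")).show()
    # +---------------+
    # |printf(a, b, c)|
    # +---------------+
    # |        aa123cc|
    # +---------------+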
@@ -4833,9 +4810,7 @@ object functions { * @group string_funcs * @since 3.5.0 */ - def ucase(str: Column): Column = withExpr { - Upper(str.expr) - } + def ucase(str: Column): Column = call_function("ucase", str) /** * Returns the leftmost `len`(`len` can be string type) characters from the string `str`, @@ -4897,7 +4872,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def curdate(): Column = withExpr { CurrentDate() } + def curdate(): Column = call_function("curdate") /** * Returns the current date at the start of query evaluation as a date column. @@ -4999,7 +4974,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def dateadd(start: Column, days: Column): Column = date_add(start, days) + def dateadd(start: Column, days: Column): Column = + call_function("dateadd", start, days) /** * Returns the date that is `days` days before `start` @@ -5064,7 +5040,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def date_diff(end: Column, start: Column): Column = datediff(end, start) + def date_diff(end: Column, start: Column): Column = + call_function("date_diff", end, start) /** * Create date from the number of `days` since 1970-01-01. @@ -5121,7 +5098,7 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def day(e: Column): Column = dayofmonth(e) + def day(e: Column): Column = call_function("day", e) /** * Extracts the day of the year as an integer from a given date/timestamp/string. @@ -5160,7 +5137,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def date_part(field: Column, source: Column): Column = call_function("date_part", field, source) + def date_part(field: Column, source: Column): Column = + call_function("date_part", field, source) /** * Extracts a part of the date/timestamp or interval source. @@ -5172,7 +5150,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def datepart(field: Column, source: Column): Column = call_function("datepart", field, source) + def datepart(field: Column, source: Column): Column = + call_function("datepart", field, source) /** * Returns the last day of the month which the given date belongs to. @@ -5425,9 +5404,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def try_to_timestamp(s: Column, format: Column): Column = withExpr { - new ParseToTimestamp(s.expr, format.expr) - } + def try_to_timestamp(s: Column, format: Column): Column = + call_function("try_to_timestamp", s, format) /** * Parses the `s` to a timestamp. The function always returns null on an invalid @@ -5437,9 +5415,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def try_to_timestamp(s: Column): Column = withExpr { - new ParseToTimestamp(s.expr) - } + def try_to_timestamp(s: Column): Column = + call_function("try_to_timestamp", s) /** * Converts the column into `DateType` by casting rules to `DateType`. @@ -5876,9 +5853,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_timestamp_ltz(timestamp: Column, format: Column): Column = withExpr { - ParseToTimestamp(timestamp.expr, Some(format.expr), TimestampType) - } + def to_timestamp_ltz(timestamp: Column, format: Column): Column = + call_function("to_timestamp_ltz", timestamp, format) /** * Parses the `timestamp` expression with the default format to a timestamp without time zone. 
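One nuance worth noting in the datetime hunks above: some aliases resolve through the registry to their canonical expression, so the printed name can still differ from the API name. A sketch (same assumed `spark` session; output from this commit's doctests):

    import pyspark.sql.functions as sf

    # dateadd routes through call_function("dateadd", ...); the analyzer
    # resolves it to the date_add expression, hence the date_add(dt, 1) header.
    spark.createDataFrame([("2015-04-08",)], ["dt"]) \
        .select(sf.dateadd("dt", 1)).show()
    # +---------------+
    # |date_add(dt, 1)|
    # +---------------+
    # |     2015-04-09|
    # +---------------+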
@@ -5887,9 +5863,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_timestamp_ltz(timestamp: Column): Column = withExpr { - ParseToTimestamp(timestamp.expr, None, TimestampType) - } + def to_timestamp_ltz(timestamp: Column): Column = + call_function("to_timestamp_ltz", timestamp) /** * Parses the `timestamp_str` expression with the `format` expression @@ -5898,9 +5873,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_timestamp_ntz(timestamp: Column, format: Column): Column = withExpr { - ParseToTimestamp(timestamp.expr, Some(format.expr), TimestampNTZType) - } + def to_timestamp_ntz(timestamp: Column, format: Column): Column = + call_function("to_timestamp_ntz", timestamp, format) /** * Parses the `timestamp` expression with the default format to a timestamp without time zone. @@ -5909,9 +5883,8 @@ object functions { * @group datetime_funcs * @since 3.5.0 */ - def to_timestamp_ntz(timestamp: Column): Column = withExpr { - ParseToTimestamp(timestamp.expr, None, TimestampNTZType) - } + def to_timestamp_ntz(timestamp: Column): Column = + call_function("to_timestamp_ntz", timestamp) /** * Returns the UNIX timestamp of the given time. @@ -7016,7 +6989,7 @@ object functions { * @group collection_funcs * @since 3.5.0 */ - def cardinality(e: Column): Column = size(e) + def cardinality(e: Column): Column = call_function("cardinality", e) /** * Sorts the input array for the given column in ascending order, @@ -7074,7 +7047,7 @@ object functions { * @group agg_funcs * @since 3.5.0 */ - def array_agg(e: Column): Column = collect_list(e) + def array_agg(e: Column): Column = call_function("array_agg", e) /** * Returns a random permutation of the given array. @@ -7376,9 +7349,8 @@ object functions { * @group "xml_funcs" * @since 3.5.0 */ - def xpath_number(x: Column, p: Column): Column = withExpr { - XPathDouble(x.expr, p.expr) - } + def xpath_number(x: Column, p: Column): Column = + call_function("xpath_number", x, p) /** * Returns a float value, the value zero if no match is found, @@ -7677,10 +7649,9 @@ object functions { hours: Column, mins: Column, secs: Column, - timezone: Column): Column = withExpr { - MakeTimestamp(years.expr, months.expr, days.expr, hours.expr, - mins.expr, secs.expr, Some(timezone.expr), dataType = TimestampType) - } + timezone: Column): Column = + call_function("make_timestamp_ltz", + years, months, days, hours, mins, secs, timezone) /** * Create the current timestamp with local time zone from years, months, days, hours, mins and @@ -7696,10 +7667,9 @@ object functions { days: Column, hours: Column, mins: Column, - secs: Column): Column = withExpr { - MakeTimestamp(years.expr, months.expr, days.expr, hours.expr, - mins.expr, secs.expr, dataType = TimestampType) - } + secs: Column): Column = + call_function("make_timestamp_ltz", + years, months, days, hours, mins, secs) /** * Create local date-time from years, months, days, hours, mins, secs fields. If the @@ -7715,10 +7685,9 @@ object functions { days: Column, hours: Column, mins: Column, - secs: Column): Column = withExpr { - MakeTimestamp(years.expr, months.expr, days.expr, hours.expr, - mins.expr, secs.expr, dataType = TimestampNTZType) - } + secs: Column): Column = + call_function("make_timestamp_ntz", + years, months, days, hours, mins, secs) /** * Make year-month interval from years, months. 
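Similarly for the collection hunks above (same assumed `spark` session; output mirrors the updated `cardinality` doctest):

    import pyspark.sql.functions as sf

    # cardinality(...) used to be lowered to size(...); it now keeps its name.
    spark.createDataFrame([([1, 2, 3],), ([1],), ([],)], ["data"]) \
        .select(sf.cardinality("data")).show()
    # +-----------------+
    # |cardinality(data)|
    # +-----------------+
    # |                3|
    # |                1|
    # |                0|
    # +-----------------+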
@@ -7785,9 +7754,8 @@ object functions {
    * @group predicates_funcs
    * @since 3.5.0
    */
-  def ifnull(col1: Column, col2: Column): Column = withExpr {
-    new Nvl(col1.expr, col2.expr)
-  }
+  def ifnull(col1: Column, col2: Column): Column =
+    call_function("ifnull", col1, col2)
 
   /**
    * Returns true if `col` is not null, or false otherwise.
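Finally, the `ifnull` hunk above behaves the same way (same assumed `spark` session; output from the updated doctest):

    import pyspark.sql.functions as sf

    # ifnull keeps its own name rather than surfacing as nvl/coalesce.
    spark.createDataFrame([(None,), (1,)], ["e"]) \
        .select(sf.ifnull("e", sf.lit(8))).show()
    # +------------+
    # |ifnull(e, 8)|
    # +------------+
    # |           8|
    # |           1|
    # +------------+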