This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e3ba9cf0403 [SPARK-45620][PYTHON] Fix user-facing APIs related to Python UDTF to use camelCase
e3ba9cf0403 is described below
commit e3ba9cf0403ade734f87621472088687e533b2cd
Author: Takuya UESHIN <[email protected]>
AuthorDate: Mon Oct 23 10:35:30 2023 +0900
[SPARK-45620][PYTHON] Fix user-facing APIs related to Python UDTF to use camelCase
### What changes were proposed in this pull request?

Fix user-facing APIs related to Python UDTF to use camelCase.

### Why are the changes needed?

To keep the naming convention for user-facing APIs.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Updated the related tests.

### Was this patch authored or co-authored using generative AI tooling?

No.
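For context, a minimal sketch (an editor's illustration, not part of this commit) of what the rename means for a UDTF author; the `Echo` class and its single output column are assumptions:

```python
from pyspark.sql.functions import udtf
from pyspark.sql.types import StructType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult

@udtf
class Echo:
    @staticmethod
    def analyze(a: AnalyzeArgument) -> AnalyzeResult:
        # Previously a.data_type and a.is_table; now camelCase.
        assert not a.isTable
        return AnalyzeResult(StructType().add("a", a.dataType))

    def eval(self, a):
        yield a,
```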
Closes #43470 from ueshin/issues/SPARK-45620/field_names.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/docs/source/user_guide/sql/python_udtf.rst | 22 +++---
python/pyspark/sql/functions.py | 12 ++--
python/pyspark/sql/tests/test_udtf.py | 84 +++++++++++------------
python/pyspark/sql/udtf.py | 24 +++----
python/pyspark/sql/worker/analyze_udtf.py | 12 ++--
5 files changed, 77 insertions(+), 77 deletions(-)
diff --git a/python/docs/source/user_guide/sql/python_udtf.rst b/python/docs/source/user_guide/sql/python_udtf.rst
index fb42644dc70..0e0c6e28578 100644
--- a/python/docs/source/user_guide/sql/python_udtf.rst
+++ b/python/docs/source/user_guide/sql/python_udtf.rst
@@ -77,29 +77,29 @@ To implement a Python UDTF, you first need to define a class implementing the me
the particular UDTF call under consideration. Each parameter is an instance of the
`AnalyzeArgument` class, which contains fields including the provided argument's data
type and value (in the case of literal scalar arguments only). For table arguments, the
- `is_table` field is set to true and the `data_type` field is a StructType representing
+ `isTable` field is set to true and the `dataType` field is a StructType representing
the table's column types:
- data_type: DataType
+ dataType: DataType
value: Optional[Any]
- is_table: bool
+ isTable: bool
This method returns an instance of the `AnalyzeResult` class which includes the result
table's schema as a StructType. If the UDTF accepts an input table argument, then the
`AnalyzeResult` can also include a requested way to partition the rows of the input
- table across several UDTF calls. If `with_single_partition` is set to True, the query
+ table across several UDTF calls. If `withSinglePartition` is set to True, the query
planner will arrange a repartitioning operation from the previous execution stage such
that all rows of the input table are consumed by the `eval` method from exactly one
- instance of the UDTF class. On the other hand, if the `partition_by` list is non-empty,
+ instance of the UDTF class. On the other hand, if the `partitionBy` list is non-empty,
the query planner will arrange a repartitioning such that all rows with each unique
combination of values of the partitioning columns are consumed by a separate unique
- instance of the UDTF class. If `order_by` is non-empty, this specifies the requested
+ instance of the UDTF class. If `orderBy` is non-empty, this specifies the requested
ordering of rows within each partition.
schema: StructType
- with_single_partition: bool = False
- partition_by: Sequence[PartitioningColumn] = field(default_factory=tuple)
- order_by: Sequence[OrderingColumn] = field(default_factory=tuple)
+ withSinglePartition: bool = False
+ partitionBy: Sequence[PartitioningColumn] = field(default_factory=tuple)
+ orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
Examples
--------
@@ -116,7 +116,7 @@ To implement a Python UDTF, you first need to define a class implementing the me
>>> def analyze(self, *args) -> AnalyzeResult:
... assert len(args) == 1, "This function accepts one argument only"
- ... assert args[0].data_type == StringType(), "Only string arguments are supported"
+ ... assert args[0].dataType == StringType(), "Only string arguments are supported"
... text = args[0].value
... schema = StructType()
... for index, word in enumerate(text.split(" ")):
@@ -128,7 +128,7 @@ To implement a Python UDTF, you first need to define a class implementing the me
>>> def analyze(self, **kwargs) -> AnalyzeResult:
... assert len(kwargs) == 1, "This function accepts one argument only"
... assert "text" in kwargs, "An argument named 'text' is required"
- ... assert kwargs["text"].data_type == StringType(), "Only strings are supported"
+ ... assert kwargs["text"].dataType == StringType(), "Only strings are supported"
... text = args["text"]
... schema = StructType()
... for index, word in enumerate(text.split(" ")):
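To make the partitioning semantics described above concrete, here is a hedged sketch of an `analyze` that requests partitioning and ordering with the new camelCase fields; the column names "key" and "ts" are assumptions for illustration:

```python
from pyspark.sql.types import IntegerType, StructType
from pyspark.sql.udtf import (
    AnalyzeArgument,
    AnalyzeResult,
    OrderingColumn,
    PartitioningColumn,
)

def analyze(table_arg: AnalyzeArgument) -> AnalyzeResult:
    assert table_arg.isTable  # only meaningful for TABLE arguments
    return AnalyzeResult(
        schema=StructType().add("count", IntegerType()),
        # One UDTF instance per distinct value of "key" ...
        partitionBy=[PartitioningColumn("key")],
        # ... with rows inside each partition sorted by "ts".
        orderBy=[OrderingColumn("ts")],
    )
```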
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 74ecc77e7d7..05c22685b09 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -17184,9 +17184,9 @@ def udtf(
- The number and order of arguments are the same as the UDTF inputs
- Each argument is a :class:`pyspark.sql.udtf.AnalyzeArgument`, containing:
- - data_type: DataType
+ - dataType: DataType
- value: Any: the calculated value if the argument is foldable; otherwise None
- - is_table: bool: True if the argument is a table argument
+ - isTable: bool: True if the argument is a table argument
and return a :class:`pyspark.sql.udtf.AnalyzeResult`, containing:
@@ -17198,7 +17198,7 @@ def udtf(
... class TestUDTFWithAnalyze:
... @staticmethod
... def analyze(a: AnalyzeArgument, b: AnalyzeArgument) -> AnalyzeResult:
- ... return AnalyzeResult(StructType().add("a", a.data_type).add("b", b.data_type))
+ ... return AnalyzeResult(StructType().add("a", a.dataType).add("b", b.dataType))
...
... def eval(self, a, b):
... yield a, b
@@ -17219,9 +17219,9 @@ def udtf(
... a: AnalyzeArgument, b: AnalyzeArgument, **kwargs: AnalyzeArgument
... ) -> AnalyzeResult:
... return AnalyzeResult(
- ... StructType().add("a", a.data_type)
- ... .add("b", b.data_type)
- ... .add("x", kwargs["x"].data_type)
+ ... StructType().add("a", a.dataType)
+ ... .add("b", b.dataType)
+ ... .add("x", kwargs["x"].dataType)
... )
...
... def eval(self, a, b, **kwargs):
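A hedged usage sketch for the doctest above, assuming an active `spark` session and the `TestUDTFWithAnalyze` class from the docstring:

```python
# The schema produced by analyze() becomes the UDTF's output schema.
spark.udtf.register("test_udtf", TestUDTFWithAnalyze)
spark.sql("SELECT * FROM test_udtf(1, 'x')").show()
# Expected: one row with columns "a" (int) and "b" (string).
```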
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 34972a5d802..3beb916de66 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -1244,10 +1244,10 @@ class BaseUDTFTestsMixin:
@staticmethod
def analyze(a: AnalyzeArgument) -> AnalyzeResult:
assert isinstance(a, AnalyzeArgument)
- assert isinstance(a.data_type, DataType)
+ assert isinstance(a.dataType, DataType)
assert a.value is not None
- assert a.is_table is False
- return AnalyzeResult(StructType().add("a", a.data_type))
+ assert a.isTable is False
+ return AnalyzeResult(StructType().add("a", a.dataType))
def eval(self, a):
yield a,
@@ -1333,7 +1333,7 @@ class BaseUDTFTestsMixin:
class TestUDTF:
@staticmethod
def analyze(a: AnalyzeArgument, b: AnalyzeArgument) -> AnalyzeResult:
- return AnalyzeResult(StructType().add("a", a.data_type).add("b", b.data_type))
+ return AnalyzeResult(StructType().add("a", a.dataType).add("b", b.dataType))
def eval(self, a, b):
yield a, b
@@ -1364,7 +1364,7 @@ class BaseUDTFTestsMixin:
@staticmethod
def analyze(*args: AnalyzeArgument) -> AnalyzeResult:
return AnalyzeResult(
- StructType([StructField(f"col{i}", a.data_type) for i, a in enumerate(args)])
+ StructType([StructField(f"col{i}", a.dataType) for i, a in enumerate(args)])
)
def eval(self, *args):
@@ -1397,10 +1397,10 @@ class BaseUDTFTestsMixin:
@staticmethod
def analyze(a: AnalyzeArgument) -> AnalyzeResult:
assert isinstance(a, AnalyzeArgument)
- assert isinstance(a.data_type, StructType)
+ assert isinstance(a.dataType, StructType)
assert a.value is None
- assert a.is_table is True
- return AnalyzeResult(StructType().add("a",
a.data_type[0].dataType))
+ assert a.isTable is True
+ return AnalyzeResult(StructType().add("a",
a.dataType[0].dataType))
def eval(self, a: Row):
if a["id"] > 5:
@@ -1417,9 +1417,9 @@ class BaseUDTFTestsMixin:
class TestUDTF:
@staticmethod
def analyze(a: AnalyzeArgument) -> AnalyzeResult:
- assert isinstance(a.data_type, StructType)
- assert a.is_table is True
- return AnalyzeResult(a.data_type.add("is_even", BooleanType()))
+ assert isinstance(a.dataType, StructType)
+ assert a.isTable is True
+ return AnalyzeResult(a.dataType.add("is_even", BooleanType()))
def eval(self, a: Row):
yield a["id"], a["id"] % 2 == 0
@@ -1449,11 +1449,11 @@ class BaseUDTFTestsMixin:
if n.value is None or not isinstance(n.value, int) or (n.value < 1 or n.value > 10):
raise Exception("The first argument must be a scalar integer between 1 and 10")
- if row.is_table is False:
+ if row.isTable is False:
raise Exception("The second argument must be a table argument")
- assert isinstance(row.data_type, StructType)
- return AnalyzeResult(row.data_type)
+ assert isinstance(row.dataType, StructType)
+ return AnalyzeResult(row.dataType)
def eval(self, n: int, row: Row):
for _ in range(n):
@@ -1604,7 +1604,7 @@ class BaseUDTFTestsMixin:
class TestUDTF:
@staticmethod
def analyze(a: AnalyzeArgument) -> AnalyzeResult:
- return AnalyzeResult(StructType().add("a", a.data_type))
+ return AnalyzeResult(StructType().add("a", a.dataType))
def eval(self, a):
yield a,
@@ -1619,7 +1619,7 @@ class BaseUDTFTestsMixin:
class TestUDTF:
@staticmethod
def analyze(a: AnalyzeArgument, b: AnalyzeArgument) -> AnalyzeResult:
- return AnalyzeResult(StructType().add("a", a.data_type).add("b", b.data_type))
+ return AnalyzeResult(StructType().add("a", a.dataType).add("b", b.dataType))
def eval(self, a):
yield a, a + 1
@@ -1675,7 +1675,7 @@ class BaseUDTFTestsMixin:
class TestUDTF:
@staticmethod
def analyze(a: AnalyzeArgument) -> AnalyzeResult:
- return AnalyzeResult(StructType().add(colname.value, a.data_type))
+ return AnalyzeResult(StructType().add(colname.value, a.dataType))
def eval(self, a):
assert colname.value == "col1"
@@ -1700,7 +1700,7 @@ class BaseUDTFTestsMixin:
@staticmethod
def analyze(a: AnalyzeArgument) -> AnalyzeResult:
test_accum.add(1)
- return AnalyzeResult(StructType().add("col1", a.data_type))
+ return AnalyzeResult(StructType().add("col1", a.dataType))
def eval(self, a):
test_accum.add(10)
@@ -1739,7 +1739,7 @@ class BaseUDTFTestsMixin:
@staticmethod
def analyze(a: AnalyzeArgument) -> AnalyzeResult:
- return AnalyzeResult(StructType().add(TestUDTF.call_my_func(), a.data_type))
+ return AnalyzeResult(StructType().add(TestUDTF.call_my_func(), a.dataType))
def eval(self, a):
assert TestUDTF.call_my_func() == "col1"
@@ -1779,7 +1779,7 @@ class BaseUDTFTestsMixin:
@staticmethod
def analyze(a: AnalyzeArgument) -> AnalyzeResult:
- return AnalyzeResult(StructType().add(TestUDTF.call_my_func(), a.data_type))
+ return AnalyzeResult(StructType().add(TestUDTF.call_my_func(), a.dataType))
def eval(self, a):
assert TestUDTF.call_my_func() == "col1"
@@ -1826,7 +1826,7 @@ class BaseUDTFTestsMixin:
@staticmethod
def analyze(a: AnalyzeArgument) -> AnalyzeResult:
- return AnalyzeResult(StructType().add(TestUDTF.read_my_archive(), a.data_type))
+ return AnalyzeResult(StructType().add(TestUDTF.read_my_archive(), a.dataType))
def eval(self, a):
assert TestUDTF.read_my_archive() == "col1"
@@ -1867,7 +1867,7 @@ class BaseUDTFTestsMixin:
@staticmethod
def analyze(a: AnalyzeArgument) -> AnalyzeResult:
- return AnalyzeResult(StructType().add(TestUDTF.read_my_file(), a.data_type))
+ return AnalyzeResult(StructType().add(TestUDTF.read_my_file(), a.dataType))
def eval(self, a):
assert TestUDTF.read_my_file() == "col1"
@@ -1967,15 +1967,15 @@ class BaseUDTFTestsMixin:
class TestUDTF:
@staticmethod
def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult:
- assert isinstance(kwargs["a"].data_type, IntegerType)
+ assert isinstance(kwargs["a"].dataType, IntegerType)
assert kwargs["a"].value == 10
- assert not kwargs["a"].is_table
- assert isinstance(kwargs["b"].data_type, StringType)
+ assert not kwargs["a"].isTable
+ assert isinstance(kwargs["b"].dataType, StringType)
assert kwargs["b"].value == "x"
- assert not kwargs["b"].is_table
+ assert not kwargs["b"].isTable
return AnalyzeResult(
StructType(
- [StructField(key, arg.data_type) for key, arg in sorted(kwargs.items())]
+ [StructField(key, arg.dataType) for key, arg in sorted(kwargs.items())]
)
)
@@ -2000,7 +2000,7 @@ class BaseUDTFTestsMixin:
class TestUDTF:
@staticmethod
def analyze(a, b):
- return AnalyzeResult(StructType().add("a", a.data_type))
+ return AnalyzeResult(StructType().add("a", a.dataType))
def eval(self, a, b):
yield a,
@@ -2028,18 +2028,18 @@ class BaseUDTFTestsMixin:
class TestUDTF:
@staticmethod
def analyze(a: AnalyzeArgument, b: Optional[AnalyzeArgument] = None):
- assert isinstance(a.data_type, IntegerType)
+ assert isinstance(a.dataType, IntegerType)
assert a.value == 10
- assert not a.is_table
+ assert not a.isTable
if b is not None:
- assert isinstance(b.data_type, StringType)
+ assert isinstance(b.dataType, StringType)
assert b.value == "z"
- assert not b.is_table
- schema = StructType().add("a", a.data_type)
+ assert not b.isTable
+ schema = StructType().add("a", a.dataType)
if b is None:
return AnalyzeResult(schema.add("b", IntegerType()))
else:
- return AnalyzeResult(schema.add("b", b.data_type))
+ return AnalyzeResult(schema.add("b", b.dataType))
def eval(self, a, b=100):
yield a, b
@@ -2298,8 +2298,8 @@ class BaseUDTFTestsMixin:
.add("count", IntegerType())
.add("total", IntegerType())
.add("last", IntegerType()),
- with_single_partition=True,
- order_by=[OrderingColumn("input"),
OrderingColumn("partition_col")],
+ withSinglePartition=True,
+ orderBy=[OrderingColumn("input"),
OrderingColumn("partition_col")],
)
def eval(self, row: Row):
@@ -2352,8 +2352,8 @@ class BaseUDTFTestsMixin:
.add("count", IntegerType())
.add("total", IntegerType())
.add("last", IntegerType()),
- partition_by=[PartitioningColumn("partition_col")],
- order_by=[
+ partitionBy=[PartitioningColumn("partition_col")],
+ orderBy=[
OrderingColumn(name="input", ascending=True,
overrideNullsFirst=False)
],
)
@@ -2433,16 +2433,16 @@ class BaseUDTFTestsMixin:
def analyze(argument, _):
if (
argument.value is None
- or argument.is_table
+ or argument.isTable
or not isinstance(argument.value, str)
or len(argument.value) == 0
):
raise Exception("The first argument must be non-empty
string")
- assert argument.data_type == StringType()
- assert not argument.is_table
+ assert argument.dataType == StringType()
+ assert not argument.isTable
return AnalyzeResultWithBuffer(
schema=StructType().add("total",
IntegerType()).add("buffer", StringType()),
- with_single_partition=True,
+ withSinglePartition=True,
buffer=argument.value,
)
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index 26ce68111db..aac212ffde9 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -48,17 +48,17 @@ class AnalyzeArgument:
Parameters
----------
- data_type : :class:`DataType`
+ dataType : :class:`DataType`
The argument's data type
value : Optional[Any]
The calculated value if the argument is foldable; otherwise None
- is_table : bool
+ isTable : bool
If True, the argument is a table argument.
"""
- data_type: DataType
+ dataType: DataType
value: Optional[Any]
- is_table: bool
+ isTable: bool
@dataclass(frozen=True)
@@ -97,25 +97,25 @@ class AnalyzeResult:
----------
schema : :class:`StructType`
The schema that the Python UDTF will return.
- with_single_partition : bool
+ withSinglePartition : bool
If true, the UDTF is specifying for Catalyst to repartition all rows of the input TABLE
argument to one collection for consumption by exactly one instance of the corresponding
UDTF class.
- partition_by : Sequence[PartitioningColumn]
+ partitionBy : Sequence[PartitioningColumn]
If non-empty, this is a sequence of columns that the UDTF is specifying for Catalyst to
partition the input TABLE argument by. In this case, calls to the UDTF may not include any
explicit PARTITION BY clause; if one is present, Catalyst will return an error. This option is
- mutually exclusive with 'with_single_partition'.
- order_by: Sequence[OrderingColumn]
+ mutually exclusive with 'withSinglePartition'.
+ orderBy: Sequence[OrderingColumn]
If non-empty, this is a sequence of columns that the UDTF is specifying for Catalyst to
- sort the input TABLE argument by. Note that the 'partition_by' list must also be non-empty
+ sort the input TABLE argument by. Note that the 'partitionBy' list must also be non-empty
in this case.
"""
schema: StructType
- with_single_partition: bool = False
- partition_by: Sequence[PartitioningColumn] = field(default_factory=tuple)
- order_by: Sequence[OrderingColumn] = field(default_factory=tuple)
+ withSinglePartition: bool = False
+ partitionBy: Sequence[PartitioningColumn] = field(default_factory=tuple)
+ orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
def _create_udtf(
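Because `AnalyzeArgument` and `AnalyzeResult` are frozen dataclasses, the field rename also changes their constructor keywords, which is why the worker call sites below had to be updated in the same commit; a small illustration (values made up):

```python
from pyspark.sql.types import IntegerType
from pyspark.sql.udtf import AnalyzeArgument

# New camelCase keywords after this change:
arg = AnalyzeArgument(dataType=IntegerType(), value=42, isTable=False)
# The old spelling would now fail:
# AnalyzeArgument(data_type=IntegerType(), ...)  # TypeError: unexpected keyword
```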
diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index 9e84b880fc9..de484c9cf94 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -83,7 +83,7 @@ def read_arguments(infile: IO) -> Tuple[List[AnalyzeArgument], Dict[str, Analyze
else:
value = None
is_table = read_bool(infile) # is table argument
- argument = AnalyzeArgument(data_type=dt, value=value, is_table=is_table)
+ argument = AnalyzeArgument(dataType=dt, value=value, isTable=is_table)
is_named_arg = read_bool(infile)
if is_named_arg:
@@ -129,14 +129,14 @@ def main(infile: IO, outfile: IO) -> None:
# Return the pickled 'AnalyzeResult' class instance.
pickleSer._write_with_length(result, outfile)
# Return whether the "with single partition" property is requested.
- write_int(1 if result.with_single_partition else 0, outfile)
+ write_int(1 if result.withSinglePartition else 0, outfile)
# Return the list of partitioning columns, if any.
- write_int(len(result.partition_by), outfile)
- for partitioning_col in result.partition_by:
+ write_int(len(result.partitionBy), outfile)
+ for partitioning_col in result.partitionBy:
write_with_length(partitioning_col.name.encode("utf-8"), outfile)
# Return the requested input table ordering, if any.
- write_int(len(result.order_by), outfile)
- for ordering_col in result.order_by:
+ write_int(len(result.orderBy), outfile)
+ for ordering_col in result.orderBy:
write_with_length(ordering_col.name.encode("utf-8"), outfile)
write_int(1 if ordering_col.ascending else 0, outfile)
if ordering_col.overrideNullsFirst is None:
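For readers unfamiliar with the worker protocol above: `write_int` emits a 4-byte big-endian integer and `write_with_length` a length-prefixed payload. A rough, self-contained approximation with simplified stand-ins (not the actual `pyspark.serializers` code):

```python
import io
import struct

def write_int(value: int, stream) -> None:
    # 4-byte big-endian signed int
    stream.write(struct.pack("!i", value))

def write_with_length(payload: bytes, stream) -> None:
    write_int(len(payload), stream)
    stream.write(payload)

buf = io.BytesIO()
write_int(1, buf)                              # withSinglePartition flag
write_int(1, buf)                              # len(result.partitionBy)
write_with_length("key".encode("utf-8"), buf)  # a partitioning column name
```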
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]