This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 536445cc954 [SPARK-46028][CONNECT][PYTHON] Make `Column.__getitem__`
accept input column
536445cc954 is described below
commit 536445cc9544e22ca8b988684c0bb3df4bbed77e
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Nov 21 11:24:27 2023 -0800
[SPARK-46028][CONNECT][PYTHON] Make `Column.__getitem__` accept input column
### What changes were proposed in this pull request?
Make `Column.__getitem__` accept input column
### Why are the changes needed?
`Column.__getitem__` should accept column as input
```
In [1]: from pyspark.sql.functions import col,lit,create_map
...: from itertools import chain
...:
...: mapping = {
...: 'A': '20',
...: 'B': '28',
...: 'C': '34'
...: }
...:
...: x = [['A','10'],['B','14'],['C','17']]
...: df = spark.createDataFrame(data=x, schema = ["key", "value"])
...:
...: mapping_expr = create_map([lit(x) for x in chain(*mapping.items())])
...: df = df.withColumn("square_value", mapping_expr[col("key")])
...: df.show()
---------------------------------------------------------------------------
PySparkTypeError Traceback (most recent call last)
Cell In[1], line 14
11 df = spark.createDataFrame(data=x, schema = ["key", "value"])
13 mapping_expr = create_map([lit(x) for x in chain(*mapping.items())])
---> 14 df = df.withColumn("square_value", mapping_expr[col("key")])
15 df.show()
File ~/Dev/spark/python/pyspark/sql/connect/column.py:465, in
Column.__getitem__(self, k)
463 return self.substr(k.start, k.stop)
464 else:
--> 465 return Column(UnresolvedExtractValue(self._expr,
LiteralExpression._from_value(k)))
File ~/Dev/spark/python/pyspark/sql/connect/expressions.py:336, in
LiteralExpression._from_value(cls, value)
334 @classmethod
335 def _from_value(cls, value: Any) -> "LiteralExpression":
--> 336 return LiteralExpression(value=value,
dataType=LiteralExpression._infer_type(value))
File ~/Dev/spark/python/pyspark/sql/connect/expressions.py:329, in
LiteralExpression._infer_type(cls, value)
323 raise PySparkTypeError(
324 error_class="CANNOT_INFER_ARRAY_TYPE",
325 message_parameters={},
326 )
327 return ArrayType(LiteralExpression._infer_type(first), True)
--> 329 raise PySparkTypeError(
330 error_class="UNSUPPORTED_DATA_TYPE",
331 message_parameters={"data_type": type(value).__name__},
332 )
PySparkTypeError: [UNSUPPORTED_DATA_TYPE] Unsupported DataType `Column`.
```
### Does this PR introduce _any_ user-facing change?
yes
### How was this patch tested?
added ut
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #43930 from zhengruifeng/connect_column_getitem.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/sql/connect/column.py | 2 ++
python/pyspark/sql/tests/test_column.py | 11 +++++++++++
2 files changed, 13 insertions(+)
diff --git a/python/pyspark/sql/connect/column.py
b/python/pyspark/sql/connect/column.py
index 19ec93151f0..a6d9ca8a2ff 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -461,6 +461,8 @@ class Column:
message_parameters={},
)
return self.substr(k.start, k.stop)
+ elif isinstance(k, Column):
+ return Column(UnresolvedExtractValue(self._expr, k._expr))
else:
return Column(UnresolvedExtractValue(self._expr,
LiteralExpression._from_value(k)))
diff --git a/python/pyspark/sql/tests/test_column.py
b/python/pyspark/sql/tests/test_column.py
index db6cd321cf2..622c1f7b210 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -16,7 +16,9 @@
# limitations under the License.
#
+from itertools import chain
from pyspark.sql import Column, Row
+from pyspark.sql import functions as sf
from pyspark.sql.types import StructType, StructField, LongType
from pyspark.errors import AnalysisException, PySparkTypeError
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -207,6 +209,15 @@ class ColumnTestsMixin:
self.assertTrue("e" not in result["a2"]["d"] and "f" in
result["a2"]["d"])
+ def test_getitem_column(self):
+ mapping = {"A": "20", "B": "28", "C": "34"}
+ mapping_expr = sf.create_map([sf.lit(x) for x in
chain(*mapping.items())])
+ df = self.spark.createDataFrame(
+ data=[["A", "10"], ["B", "14"], ["C", "17"]],
+ schema=["key", "value"],
+ ).withColumn("square_value", mapping_expr[sf.col("key")])
+ self.assertEqual(df.count(), 3)
+
class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase):
pass
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]