This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 536445cc954 [SPARK-46028][CONNECT][PYTHON] Make `Column.__getitem__` 
accept input column
536445cc954 is described below

commit 536445cc9544e22ca8b988684c0bb3df4bbed77e
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Nov 21 11:24:27 2023 -0800

    [SPARK-46028][CONNECT][PYTHON] Make `Column.__getitem__` accept input column
    
    ### What changes were proposed in this pull request?
    Make `Column.__getitem__` accept input column
    
    ### Why are the changes needed?
    
    `Column.__getitem__` should accept column as input
    
    ```
    In [1]: from pyspark.sql.functions import col,lit,create_map
       ...: from itertools import chain
       ...:
       ...: mapping = {
       ...:   'A': '20',
       ...:   'B': '28',
       ...:   'C': '34'
       ...: }
       ...:
       ...: x = [['A','10'],['B','14'],['C','17']]
       ...: df = spark.createDataFrame(data=x, schema = ["key", "value"])
       ...:
       ...: mapping_expr = create_map([lit(x) for x in chain(*mapping.items())])
       ...: df = df.withColumn("square_value", mapping_expr[col("key")])
       ...: df.show()
    ---------------------------------------------------------------------------
    PySparkTypeError                          Traceback (most recent call last)
    Cell In[1], line 14
         11 df = spark.createDataFrame(data=x, schema = ["key", "value"])
         13 mapping_expr = create_map([lit(x) for x in chain(*mapping.items())])
    ---> 14 df = df.withColumn("square_value", mapping_expr[col("key")])
         15 df.show()
    
    File ~/Dev/spark/python/pyspark/sql/connect/column.py:465, in 
Column.__getitem__(self, k)
        463     return self.substr(k.start, k.stop)
        464 else:
    --> 465     return Column(UnresolvedExtractValue(self._expr, 
LiteralExpression._from_value(k)))
    
    File ~/Dev/spark/python/pyspark/sql/connect/expressions.py:336, in 
LiteralExpression._from_value(cls, value)
        334 @classmethod
        335 def _from_value(cls, value: Any) -> "LiteralExpression":
    --> 336     return LiteralExpression(value=value, 
dataType=LiteralExpression._infer_type(value))
    
    File ~/Dev/spark/python/pyspark/sql/connect/expressions.py:329, in 
LiteralExpression._infer_type(cls, value)
        323         raise PySparkTypeError(
        324             error_class="CANNOT_INFER_ARRAY_TYPE",
        325             message_parameters={},
        326         )
        327     return ArrayType(LiteralExpression._infer_type(first), True)
    --> 329 raise PySparkTypeError(
        330     error_class="UNSUPPORTED_DATA_TYPE",
        331     message_parameters={"data_type": type(value).__name__},
        332 )
    
    PySparkTypeError: [UNSUPPORTED_DATA_TYPE] Unsupported DataType `Column`.
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    Added a unit test.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #43930 from zhengruifeng/connect_column_getitem.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 python/pyspark/sql/connect/column.py    |  2 ++
 python/pyspark/sql/tests/test_column.py | 11 +++++++++++
 2 files changed, 13 insertions(+)

diff --git a/python/pyspark/sql/connect/column.py 
b/python/pyspark/sql/connect/column.py
index 19ec93151f0..a6d9ca8a2ff 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -461,6 +461,8 @@ class Column:
                     message_parameters={},
                 )
             return self.substr(k.start, k.stop)
+        elif isinstance(k, Column):
+            return Column(UnresolvedExtractValue(self._expr, k._expr))
         else:
             return Column(UnresolvedExtractValue(self._expr, 
LiteralExpression._from_value(k)))
 
diff --git a/python/pyspark/sql/tests/test_column.py 
b/python/pyspark/sql/tests/test_column.py
index db6cd321cf2..622c1f7b210 100644
--- a/python/pyspark/sql/tests/test_column.py
+++ b/python/pyspark/sql/tests/test_column.py
@@ -16,7 +16,9 @@
 # limitations under the License.
 #
 
+from itertools import chain
 from pyspark.sql import Column, Row
+from pyspark.sql import functions as sf
 from pyspark.sql.types import StructType, StructField, LongType
 from pyspark.errors import AnalysisException, PySparkTypeError
 from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -207,6 +209,15 @@ class ColumnTestsMixin:
 
         self.assertTrue("e" not in result["a2"]["d"] and "f" in 
result["a2"]["d"])
 
+    def test_getitem_column(self):
+        mapping = {"A": "20", "B": "28", "C": "34"}
+        mapping_expr = sf.create_map([sf.lit(x) for x in 
chain(*mapping.items())])
+        df = self.spark.createDataFrame(
+            data=[["A", "10"], ["B", "14"], ["C", "17"]],
+            schema=["key", "value"],
+        ).withColumn("square_value", mapping_expr[sf.col("key")])
+        self.assertEqual(df.count(), 3)
+
 
 class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase):
     pass


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to