This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 536445cc954 [SPARK-46028][CONNECT][PYTHON] Make `Column.__getitem__` accept input column 536445cc954 is described below commit 536445cc9544e22ca8b988684c0bb3df4bbed77e Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Tue Nov 21 11:24:27 2023 -0800 [SPARK-46028][CONNECT][PYTHON] Make `Column.__getitem__` accept input column ### What changes were proposed in this pull request? Make `Column.__getitem__` accept input column ### Why are the changes needed? `Column.__getitem__` should accept column as input ``` In [1]: from pyspark.sql.functions import col,lit,create_map ...: from itertools import chain ...: ...: mapping = { ...: 'A': '20', ...: 'B': '28', ...: 'C': '34' ...: } ...: ...: x = [['A','10'],['B','14'],['C','17']] ...: df = spark.createDataFrame(data=x, schema = ["key", "value"]) ...: ...: mapping_expr = create_map([lit(x) for x in chain(*mapping.items())]) ...: df = df.withColumn("square_value", mapping_expr[col("key")]) ...: df.show() --------------------------------------------------------------------------- PySparkTypeError Traceback (most recent call last) Cell In[1], line 14 11 df = spark.createDataFrame(data=x, schema = ["key", "value"]) 13 mapping_expr = create_map([lit(x) for x in chain(*mapping.items())]) ---> 14 df = df.withColumn("square_value", mapping_expr[col("key")]) 15 df.show() File ~/Dev/spark/python/pyspark/sql/connect/column.py:465, in Column.__getitem__(self, k) 463 return self.substr(k.start, k.stop) 464 else: --> 465 return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k))) File ~/Dev/spark/python/pyspark/sql/connect/expressions.py:336, in LiteralExpression._from_value(cls, value) 334 @classmethod 335 def _from_value(cls, value: Any) -> "LiteralExpression": --> 336 return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value)) File ~/Dev/spark/python/pyspark/sql/connect/expressions.py:329, in 
LiteralExpression._infer_type(cls, value) 323 raise PySparkTypeError( 324 error_class="CANNOT_INFER_ARRAY_TYPE", 325 message_parameters={}, 326 ) 327 return ArrayType(LiteralExpression._infer_type(first), True) --> 329 raise PySparkTypeError( 330 error_class="UNSUPPORTED_DATA_TYPE", 331 message_parameters={"data_type": type(value).__name__}, 332 ) PySparkTypeError: [UNSUPPORTED_DATA_TYPE] Unsupported DataType `Column`. ``` ### Does this PR introduce _any_ user-facing change? yes ### How was this patch tested? added ut ### Was this patch authored or co-authored using generative AI tooling? no Closes #43930 from zhengruifeng/connect_column_getitem. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- python/pyspark/sql/connect/column.py | 2 ++ python/pyspark/sql/tests/test_column.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py index 19ec93151f0..a6d9ca8a2ff 100644 --- a/python/pyspark/sql/connect/column.py +++ b/python/pyspark/sql/connect/column.py @@ -461,6 +461,8 @@ class Column: message_parameters={}, ) return self.substr(k.start, k.stop) + elif isinstance(k, Column): + return Column(UnresolvedExtractValue(self._expr, k._expr)) else: return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k))) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index db6cd321cf2..622c1f7b210 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -16,7 +16,9 @@ # limitations under the License. 
# +from itertools import chain from pyspark.sql import Column, Row +from pyspark.sql import functions as sf from pyspark.sql.types import StructType, StructField, LongType from pyspark.errors import AnalysisException, PySparkTypeError from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -207,6 +209,15 @@ class ColumnTestsMixin: self.assertTrue("e" not in result["a2"]["d"] and "f" in result["a2"]["d"]) + def test_getitem_column(self): + mapping = {"A": "20", "B": "28", "C": "34"} + mapping_expr = sf.create_map([sf.lit(x) for x in chain(*mapping.items())]) + df = self.spark.createDataFrame( + data=[["A", "10"], ["B", "14"], ["C", "17"]], + schema=["key", "value"], + ).withColumn("square_value", mapping_expr[sf.col("key")]) + self.assertEqual(df.count(), 3) + class ColumnTests(ColumnTestsMixin, ReusedSQLTestCase): pass --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org