gene-db commented on code in PR #48770:
URL: https://github.com/apache/spark/pull/48770#discussion_r1831547607


##########
python/pyspark/sql/tests/test_udf.py:
##########
@@ -334,68 +334,72 @@ def test_udf_with_filter_function(self):
         self.assertEqual(sel.collect(), [Row(key=1, value="1")])
 
     def test_udf_with_variant_input(self):
-        df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
-
-        u = udf(lambda u: str(u), StringType())
-        with self.assertRaises(AnalysisException) as ae:
-            df.select(u(col("v"))).collect()
-
-        self.check_error(
-            exception=ae.exception,
-            errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE",
-            messageParameters={"sqlExpr": '"<lambda>(v)"', "dataType": "VARIANT"},
-        )
+        for arrow_enabled in ["false", "true"]:
+            with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):

Review Comment:
   If arrow is disabled, does it go through the pickle path?



##########
python/pyspark/sql/tests/test_udf.py:
##########
@@ -334,68 +334,72 @@ def test_udf_with_filter_function(self):
         self.assertEqual(sel.collect(), [Row(key=1, value="1")])
 
     def test_udf_with_variant_input(self):
-        df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
-
-        u = udf(lambda u: str(u), StringType())
-        with self.assertRaises(AnalysisException) as ae:
-            df.select(u(col("v"))).collect()
-
-        self.check_error(
-            exception=ae.exception,
-            errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE",
-            messageParameters={"sqlExpr": '"<lambda>(v)"', "dataType": "VARIANT"},
-        )
+        for arrow_enabled in ["false", "true"]:
+            with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
+                df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
+                u = udf(lambda u: str(u), StringType())
+                expected = [Row(udf="{0}".format(i)) for i in range(10)]
+                result = df.select(u(col("v")).alias("udf")).collect()
+                self.assertEqual(result, expected)
 
     def test_udf_with_complex_variant_input(self):
-        df = self.spark.range(0, 10).selectExpr(
-            "named_struct('v', parse_json(cast(id as string))) struct_of_v"
-        )
-
-        u = udf(lambda u: str(u), StringType())
-
-        with self.assertRaises(AnalysisException) as ae:
-            df.select(u(col("struct_of_v"))).collect()
-
-        self.check_error(
-            exception=ae.exception,
-            errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE",
-            messageParameters={
-                "sqlExpr": '"<lambda>(struct_of_v)"',
-                "dataType": "STRUCT<v: VARIANT NOT NULL>",
-            },
-        )
+        for arrow_enabled in ["false", "true"]:
+            with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
+                df = self.spark.range(0, 10).selectExpr(
+                    "named_struct('v', parse_json(cast(id as string))) 
struct_of_v"
+                )
+                u = udf(lambda u: str(u["v"]), StringType())
+                result = df.select(u(col("struct_of_v"))).collect()
+                expected = [Row(udf="{0}".format(i)) for i in range(10)]
+                self.assertEqual(result, expected)
 
     def test_udf_with_variant_output(self):
-        # The variant value returned corresponds to a JSON string of {"a": "b"}.
-        u = udf(
-            lambda: VariantVal(bytes([2, 1, 0, 0, 2, 5, 98]), bytes([1, 1, 0, 1, 97])),
-            VariantType(),
-        )
-
-        with self.assertRaises(AnalysisException) as ae:
-            self.spark.range(0, 10).select(u()).collect()
-
-        self.check_error(
-            exception=ae.exception,
-            errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE",
-            messageParameters={"sqlExpr": '"<lambda>()"', "dataType": 
"VARIANT"},
-        )
+        for arrow_enabled in ["false", "true"]:
+            with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
+                # The variant value returned corresponds to a JSON string of {"a": "<a-j>"}.
+                u = udf(
+                    lambda i: VariantVal(
+                        bytes([2, 1, 0, 0, 2, 5, 97 + i]), bytes([1, 1, 0, 1, 97])
+                    ),
+                    VariantType(),
+                )
+                result = self.spark.range(0, 10).select(
+                    u(col("id")).cast("string").alias("udf")
+                ).collect()
+                expected = [Row(udf=f"{{\"a\":\"{chr(97 + i)}\"}}") for i in 
range(10)]
+                self.assertEqual(result, expected)
 
     def test_udf_with_complex_variant_output(self):
-        # The variant value returned corresponds to a JSON string of {"a": "b"}.
-        u = udf(
-            lambda: {"v", VariantVal(bytes([2, 1, 0, 0, 2, 5, 98]), bytes([1, 
1, 0, 1, 97]))},
-            MapType(StringType(), VariantType()),
-        )
-
-        with self.assertRaises(AnalysisException) as ae:
-            self.spark.range(0, 10).select(u()).collect()
-
-        self.check_error(
-            exception=ae.exception,
-            errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE",
-            messageParameters={"sqlExpr": '"<lambda>()"', "dataType": 
"MAP<STRING, VARIANT>"},
-        )
+        for arrow_enabled in ["false", "true"]:
+            with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
+                # The variant value returned corresponds to a JSON string of {"a": "<a-j>"}.
+                u = udf(
+                    lambda i: {
+                        "v": VariantVal(bytes([2, 1, 0, 0, 2, 5, 97 + i]), 
bytes([1, 1, 0, 1, 97]))
+                    },
+                    MapType(StringType(), VariantType()),
+                )
+                result = self.spark.range(0, 10).select(
+                    u(col("id")).cast("string").alias("udf")
+                ).collect()
+                expected = [Row(udf=f"{{v -> {{\"a\":\"{chr(97 + i)}\"}}}}") 
for i in range(10)]
+                self.assertEqual(result, expected)
+
+    def test_chained_udfs_with_variant(self):

Review Comment:
   Can we have another one with chained UDFs and a nested variant, like `array<variant>` or `struct<int, variant>`?
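   
   For instance, roughly (untested sketch; assumes `ArrayType` is imported in this test file):
   ```python
   def test_chained_udfs_with_nested_variant(self):
       for arrow_enabled in ["false", "true"]:
           with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
               df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
               # First UDF wraps the variant in an array<variant> ...
               wrap = udf(lambda v: [v], ArrayType(VariantType()))
               # ... and the second UDF unwraps it again.
               unwrap = udf(lambda arr: str(arr[0]), StringType())
               result = df.select(unwrap(wrap(col("v"))).alias("udf")).collect()
               self.assertEqual(result, [Row(udf=str(i)) for i in range(10)])
   ```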



##########
python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py:
##########
@@ -752,46 +752,87 @@ def check_vectorized_udf_return_scalar(self):
 
     def test_udf_with_variant_input(self):
         df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
-        from pyspark.sql.functions import col
 
-        scalar_f = pandas_udf(lambda u: str(u), StringType())
+        scalar_f = pandas_udf(lambda u: u.apply(str), StringType(), PandasUDFType.SCALAR)
         iter_f = pandas_udf(
-            lambda it: map(lambda u: str(u), it), StringType(), PandasUDFType.SCALAR_ITER
+            lambda it: map(lambda u: u.apply(str), it), StringType(), PandasUDFType.SCALAR_ITER
         )
 
+        expected = [Row(udf="{0}".format(i)) for i in range(10)]
+
         for f in [scalar_f, iter_f]:
-            with self.assertRaises(AnalysisException) as ae:
-                df.select(f(col("v"))).collect()
-
-            self.check_error(
-                exception=ae.exception,
-                errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE",
-                messageParameters={
-                    "sqlExpr": '"<lambda>(v)"',
-                    "dataType": "VARIANT",
-                },
-            )
+            result = df.select(f(col("v")).alias("udf")).collect()
+            self.assertEqual(result, expected)
 
     def test_udf_with_variant_output(self):
-        # Corresponds to a JSON string of {"a": "b"}.
-        returned_variant = VariantVal(bytes([2, 1, 0, 0, 2, 5, 98]), bytes([1, 1, 0, 1, 97]))
-        scalar_f = pandas_udf(lambda x: returned_variant, VariantType())
+        scalar_f = pandas_udf(
+            lambda u: u.apply(lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0]))), VariantType()
+        )
         iter_f = pandas_udf(
-            lambda it: map(lambda x: returned_variant, it), VariantType(), PandasUDFType.SCALAR_ITER
+            lambda it: map(lambda u: u.apply(
+                lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0]))
+            ), it),
+            VariantType(),
+            PandasUDFType.SCALAR_ITER
         )
 
+        expected = [Row(udf=i) for i in range(10)]

Review Comment:
   Do we have a `pandas_udf` test case with a nested variant output?
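   
   For example (sketch; assumes `ArrayType`, `pandas as pd`, and the other imports already used in this file; the exact string form that casting `array<variant>` to string produces is an assumption):
   ```python
   @pandas_udf(ArrayType(VariantType()))
   def nested_out(s: pd.Series) -> pd.Series:
       # Wrap each constructed VariantVal in a one-element array<variant>.
       return s.apply(lambda i: [VariantVal(bytes([12, i]), bytes([1, 0, 0]))])

   result = self.spark.range(0, 10).select(
       nested_out(col("id")).cast("string").alias("udf")
   ).collect()
   # expected would be [Row(udf=f"[{i}]") for i in range(10)], modulo formatting
   ```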



##########
python/pyspark/sql/pandas/types.py:
##########
@@ -1295,6 +1306,16 @@ def convert_udt(value: Any) -> Any:
 
             return convert_udt
 
+        elif isinstance(dt, VariantType):
+            def convert_variant(variant: Any) -> Any:

Review Comment:
   What is this conversion function supposed to return? It is expecting a 
python `VariantVal`, and then returns a dict?
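   
   If I read the code right, a plausible shape would be (sketch; `VariantVal.value`/`VariantVal.metadata` are the constructor arguments, assumed here to be exposed as attributes):
   ```python
   from typing import Any

   from pyspark.sql.types import VariantVal

   def convert_variant(variant: Any) -> Any:
       if variant is None:
           return None
       # Going from pandas to Arrow, flatten the VariantVal into a dict
       # matching the Arrow layout struct<value: binary, metadata: binary>.
       assert isinstance(variant, VariantVal)
       return {"value": variant.value, "metadata": variant.metadata}
   ```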



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -420,7 +421,11 @@ def __init__(
     def arrow_to_pandas(self, arrow_column):
         import pyarrow.types as types
 
-        if self._df_for_struct and types.is_struct(arrow_column.type):
+        if (
+            self._df_for_struct
+            and types.is_struct(arrow_column.type)
+            and not is_variant(arrow_column.type)
+        ):
             import pandas as pd

Review Comment:
   Why do we want to avoid this path for variant types? Maybe we should leave a 
comment explaining it.
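   
   For example, if my understanding of the reason is right, a comment like this would help (sketch):
   ```python
   if (
       self._df_for_struct
       and types.is_struct(arrow_column.type)
       # A variant is physically an Arrow struct of two binary children
       # (value, metadata), but it should surface in Python as VariantVal
       # objects rather than be expanded into a pd.DataFrame like an
       # ordinary struct column.
       and not is_variant(arrow_column.type)
   ):
   ```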



##########
python/pyspark/sql/pandas/types.py:
##########
@@ -221,6 +221,15 @@ def to_arrow_schema(
     return pa.schema(fields)
 
 
+def is_variant(at: "pa.DataType") -> bool:
+    """Check if a PyArrow struct data type represents a variant"""
+    import pyarrow.types as types
+    assert types.is_struct(at)

Review Comment:
   Why is this an assert? Should this just return false if it is not a struct?
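   
   For example (sketch):
   ```python
   import pyarrow as pa
   import pyarrow.types as types

   def is_variant(at: "pa.DataType") -> bool:
       """Check if a PyArrow data type represents a variant."""
       if not types.is_struct(at):
           return False  # non-structs are never variants
       return any(
           field.name == "metadata"
           and field.metadata is not None
           and b"variant" in field.metadata
           for field in at
       )
   ```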



##########
sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala:
##########
@@ -146,19 +146,33 @@ private[sql] object ArrowUtils {
           ArrowType.Struct.INSTANCE,
           null,
           Map("variant" -> "true").asJava)
+        val metadataFieldType = new FieldType(
+          false,
+          toArrowType(BinaryType, timeZoneId, largeVarTypes),
+          null,
+          Map("variant" -> "true").asJava
+        )
         new Field(
           name,
           fieldType,
           Seq(
             toArrowField("value", BinaryType, false, timeZoneId, 
largeVarTypes),
-            toArrowField("metadata", BinaryType, false, timeZoneId, 
largeVarTypes)).asJava)
+            new Field("metadata", metadataFieldType, 
Seq.empty[Field].asJava)).asJava)
       case dataType =>
         val fieldType =
           new FieldType(nullable, toArrowType(dataType, timeZoneId, largeVarTypes), null)
         new Field(name, fieldType, Seq.empty[Field].asJava)
     }
   }
 
+  def isVariantField(field: Field): Boolean = {
+    assert(field.getType.isInstanceOf[ArrowType.Struct])

Review Comment:
   Should this assert, or should it just return false when it is not a struct?



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PythonUDF.scala:
##########
@@ -169,6 +148,23 @@ case class PythonUDAF(
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): PythonUDAF =
     copy(children = newChildren)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val check = super.checkInputDataTypes()
+    if (check.isFailure) {
+      check
+    } else {
+      val exprReturningVariant = children.collectFirst {
+        case e: Expression if VariantExpressionEvalUtils.typeContainsVariant(e.dataType) => e
+      }
+      exprReturningVariant match {
+        case Some(e) => TypeCheckResult.DataTypeMismatch(
+          errorSubClass = "UNSUPPORTED_UDF_INPUT_TYPE",

Review Comment:
   Does UDAF also use the error class for UDF?



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -2530,6 +2531,28 @@ def terminate(self):
             [Row(current=4, total=4), Row(current=13, total=4), Row(current=20, total=1)],
         )
 
+    def test_udtf_with_variant_input(self):

Review Comment:
   Are we able to get variations of these 2 tests, with nested variant types?
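   
   Something along these lines, perhaps (untested sketch; the exact invocation may differ):
   ```python
   from pyspark.sql.functions import udtf

   @udtf(returnType="r: string")
   class StringifyNestedVariant:
       def eval(self, arr):
           # For an array<variant> argument, each element should arrive
           # as a VariantVal.
           for v in arr:
               yield (str(v),)

   self.spark.udtf.register("stringify_nested_variant", StringifyNestedVariant)
   rows = self.spark.sql(
       "SELECT * FROM stringify_nested_variant(array(parse_json('1'), parse_json('2')))"
   ).collect()
   ```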



##########
python/pyspark/sql/pandas/types.py:
##########
@@ -171,7 +171,7 @@ def to_arrow_type(
     elif type(dt) == VariantType:
         fields = [
             pa.field("value", pa.binary(), nullable=False),
-            pa.field("metadata", pa.binary(), nullable=False),
+            pa.field("metadata", pa.binary(), nullable=False, 
metadata={b"variant": b"true"}),

Review Comment:
   Should we comment on the scheme we are using to "tag" an arrow struct as a 
variant?
   
   It looks like we are attaching some metadata to the `metadata` field of the 
struct?



##########
python/pyspark/sql/tests/test_udf.py:
##########
@@ -334,68 +334,72 @@ def test_udf_with_filter_function(self):
         self.assertEqual(sel.collect(), [Row(key=1, value="1")])
 
     def test_udf_with_variant_input(self):
-        df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
-
-        u = udf(lambda u: str(u), StringType())
-        with self.assertRaises(AnalysisException) as ae:
-            df.select(u(col("v"))).collect()
-
-        self.check_error(
-            exception=ae.exception,
-            errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE",
-            messageParameters={"sqlExpr": '"<lambda>(v)"', "dataType": "VARIANT"},
-        )
+        for arrow_enabled in ["false", "true"]:
+            with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
+                df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
+                u = udf(lambda u: str(u), StringType())
+                expected = [Row(udf="{0}".format(i)) for i in range(10)]
+                result = df.select(u(col("v")).alias("udf")).collect()
+                self.assertEqual(result, expected)
 
     def test_udf_with_complex_variant_input(self):
-        df = self.spark.range(0, 10).selectExpr(
-            "named_struct('v', parse_json(cast(id as string))) struct_of_v"
-        )
-
-        u = udf(lambda u: str(u), StringType())
-
-        with self.assertRaises(AnalysisException) as ae:
-            df.select(u(col("struct_of_v"))).collect()
-
-        self.check_error(
-            exception=ae.exception,
-            errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE",
-            messageParameters={
-                "sqlExpr": '"<lambda>(struct_of_v)"',
-                "dataType": "STRUCT<v: VARIANT NOT NULL>",
-            },
-        )
+        for arrow_enabled in ["false", "true"]:
+            with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
+                df = self.spark.range(0, 10).selectExpr(
+                    "named_struct('v', parse_json(cast(id as string))) 
struct_of_v"
+                )
+                u = udf(lambda u: str(u["v"]), StringType())
+                result = df.select(u(col("struct_of_v"))).collect()

Review Comment:
   Can we also add unit tests with inputs of nested types (sketch below)?
   - `array<variant>`
   - `map<string, variant>`
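   
   Something like this, mirroring the struct test above (sketch):
   ```python
   def test_udf_with_nested_variant_input(self):
       for arrow_enabled in ["false", "true"]:
           with self.sql_conf({"spark.sql.execution.pythonUDF.arrow.enabled": arrow_enabled}):
               # array<variant>: elements should arrive as VariantVal
               df = self.spark.range(0, 10).selectExpr(
                   "array(parse_json(cast(id as string))) array_of_v"
               )
               u = udf(lambda a: str(a[0]), StringType())
               result = df.select(u(col("array_of_v")).alias("udf")).collect()
               self.assertEqual(result, [Row(udf=str(i)) for i in range(10)])

               # map<string, variant>: values should arrive as VariantVal
               df = self.spark.range(0, 10).selectExpr(
                   "map('k', parse_json(cast(id as string))) map_of_v"
               )
               u = udf(lambda m: str(m["k"]), StringType())
               result = df.select(u(col("map_of_v")).alias("udf")).collect()
               self.assertEqual(result, [Row(udf=str(i)) for i in range(10)])
   ```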
   



##########
python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py:
##########
@@ -752,46 +752,87 @@ def check_vectorized_udf_return_scalar(self):
 
     def test_udf_with_variant_input(self):
         df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")

Review Comment:
   Do we have a `pandas_udf` test case with a nested variant input?
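   
   If not, one could look like this (sketch; assumes array elements surface as `VariantVal` and that `pandas` is imported as `pd`):
   ```python
   nested_df = self.spark.range(0, 10).selectExpr(
       "array(parse_json(cast(id as string))) a"
   )

   @pandas_udf(StringType())
   def first_elem(s: pd.Series) -> pd.Series:
       # Each row arrives as a sequence of VariantVal.
       return s.apply(lambda arr: str(arr[0]))

   result = nested_df.select(first_elem("a").alias("udf")).collect()
   # expected: [Row(udf=str(i)) for i in range(10)]
   ```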



##########
python/pyspark/sql/pandas/types.py:
##########
@@ -221,6 +221,15 @@ def to_arrow_schema(
     return pa.schema(fields)
 
 
+def is_variant(at: "pa.DataType") -> bool:
+    """Check if a PyArrow struct data type represents a variant"""
+    import pyarrow.types as types
+    assert types.is_struct(at)
+
+    return any((field.name == "metadata" and b"variant" in field.metadata and

Review Comment:
   Shouldn't we check that the fields are `metadata` and `value`?
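   
   A stricter version could be (sketch):
   ```python
   def is_variant(at: "pa.DataType") -> bool:
       """Check if a PyArrow struct data type represents a variant."""
       import pyarrow.types as types

       if not types.is_struct(at):
           return False
       if {field.name for field in at} != {"value", "metadata"}:
           return False
       metadata = at.field("metadata").metadata
       return metadata is not None and b"variant" in metadata
   ```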



##########
python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py:
##########
@@ -752,46 +752,87 @@ def check_vectorized_udf_return_scalar(self):
 
     def test_udf_with_variant_input(self):
         df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
-        from pyspark.sql.functions import col
 
-        scalar_f = pandas_udf(lambda u: str(u), StringType())
+        scalar_f = pandas_udf(lambda u: u.apply(str), StringType(), PandasUDFType.SCALAR)
         iter_f = pandas_udf(
-            lambda it: map(lambda u: str(u), it), StringType(), PandasUDFType.SCALAR_ITER
+            lambda it: map(lambda u: u.apply(str), it), StringType(), PandasUDFType.SCALAR_ITER
         )
 
+        expected = [Row(udf="{0}".format(i)) for i in range(10)]
+
         for f in [scalar_f, iter_f]:
-            with self.assertRaises(AnalysisException) as ae:
-                df.select(f(col("v"))).collect()
-
-            self.check_error(
-                exception=ae.exception,
-                errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE",
-                messageParameters={
-                    "sqlExpr": '"<lambda>(v)"',
-                    "dataType": "VARIANT",
-                },
-            )
+            result = df.select(f(col("v")).alias("udf")).collect()
+            self.assertEqual(result, expected)
 
     def test_udf_with_variant_output(self):
-        # Corresponds to a JSON string of {"a": "b"}.
-        returned_variant = VariantVal(bytes([2, 1, 0, 0, 2, 5, 98]), bytes([1, 1, 0, 1, 97]))
-        scalar_f = pandas_udf(lambda x: returned_variant, VariantType())
+        scalar_f = pandas_udf(
+            lambda u: u.apply(lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0]))), VariantType()
+        )
         iter_f = pandas_udf(
-            lambda it: map(lambda x: returned_variant, it), VariantType(), PandasUDFType.SCALAR_ITER
+            lambda it: map(lambda u: u.apply(
+                lambda i: VariantVal(bytes([12, i]), bytes([1, 0, 0]))
+            ), it),
+            VariantType(),
+            PandasUDFType.SCALAR_ITER
         )
 
+        expected = [Row(udf=i) for i in range(10)]
+
         for f in [scalar_f, iter_f]:
-            with self.assertRaises(AnalysisException) as ae:
-                self.spark.range(0, 10).select(f()).collect()
-
-            self.check_error(
-                exception=ae.exception,
-                errorClass="DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE",
-                messageParameters={
-                    "sqlExpr": '"<lambda>()"',
-                    "dataType": "VARIANT",
-                },
-            )
+            # with self.assertRaises(AnalysisException) as ae:

Review Comment:
   remove?



##########
sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala:
##########
@@ -146,19 +146,33 @@ private[sql] object ArrowUtils {
           ArrowType.Struct.INSTANCE,
           null,
           Map("variant" -> "true").asJava)
+        val metadataFieldType = new FieldType(

Review Comment:
   Do we have 2 ways of tagging something as a variant? The line above attaches metadata to the struct itself, but here we also attach metadata to the `metadata` child field. This is confusing, so could you explain it with some comments? Do we need both schemes?



##########
python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py:
##########
@@ -752,46 +752,87 @@ def check_vectorized_udf_return_scalar(self):
 
     def test_udf_with_variant_input(self):
         df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
-        from pyspark.sql.functions import col
 
-        scalar_f = pandas_udf(lambda u: str(u), StringType())
+        scalar_f = pandas_udf(lambda u: u.apply(str), StringType(), PandasUDFType.SCALAR)

Review Comment:
   Does `pandas_udf` go through the same path as an Arrow UDF?



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -505,7 +510,12 @@ def _create_batch(self, series):
 
         arrs = []
         for s, t in series:
-            if self._struct_in_pandas == "dict" and t is not None and 
pa.types.is_struct(t):
+            if (
+                self._struct_in_pandas == "dict"
+                and t is not None
+                and pa.types.is_struct(t)
+                and not is_variant(t)
+            ):
                 # A pandas UDF should return pd.DataFrame when the return type is a struct type.

Review Comment:
   Why do we want to avoid this path for variant types? Maybe we should leave a 
comment explaining it.
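   
   If the reason is the same as on the deserialization side, a comment like this would help (sketch):
   ```python
   if (
       self._struct_in_pandas == "dict"
       and t is not None
       and pa.types.is_struct(t)
       # A variant result comes back as a pd.Series of VariantVal objects,
       # not as a pd.DataFrame, even though its Arrow type is a struct of
       # (value, metadata) binaries, so skip the DataFrame handling here.
       and not is_variant(t)
   ):
   ```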


