cloud-fan commented on code in PR #56505:
URL: https://github.com/apache/spark/pull/56505#discussion_r3430465052


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala:
##########
@@ -836,8 +865,26 @@ case object SparkShreddingUtils {
     val resultRow = new GenericInternalRow(numFields)
     var fieldIdx = 0
     while (fieldIdx < numFields) {
-      resultRow.update(fieldIdx, extractField(inputRow, topLevelMetadata, 
schema,
-        fields(fieldIdx).path, fields(fieldIdx).reader))
+      val field = fields(fieldIdx)
+      if (field.isCastError) {
+        // Filled by the paired data field on failure; left null otherwise.
+      } else if (field.castErrorOrdinal >= 0) {
+        try {
+          val value = extractField(inputRow, topLevelMetadata, schema, 
field.path, field.reader)
+          resultRow.update(fieldIdx, value)
+        } catch {
+          case e: SparkRuntimeException if e.getCondition == 
"INVALID_VARIANT_CAST" =>
+            // Recover the offending value from the error's `value` message 
parameter so the

Review Comment:
   The recovered `value` matches the eager path, but the re-raised `dataType` 
does not. For a composite target (`array`/`struct`/`map`) where an *inner* 
element fails, the eager cast raises `INVALID_VARIANT_CAST` with the 
**element** type (e.g. `INT`), since `VariantGet.cast` recurses with the 
element type. The deferred path stores only the `value` param here, and 
`UnwrapVariantCastError` re-raises with `dataType = value.dataType` — the 
**field's full target type** (e.g. `ARRAY<INT>`). So the condition and 
offending value match, but the displayed type differs whenever an inner 
element of a composite target fails (top-level scalar casts are unaffected). The new tests assert only the
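condition via `hasCastCondition`, so this isn't covered. Since the PR targets 
exception-semantics fidelity ("surface the same value that an eager raise would 
have"), is exact message parity with the eager path intended? If so, the deferred path could also record the error's `dataType` message parameter, and a test could assert on the full message parameters rather than only the condition.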



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScanSuite.scala:
##########
@@ -49,13 +73,235 @@ trait PushVariantIntoScanSuiteBase extends 
SharedSparkSession {
     }
   }
 
+  test(s"Strict cast wraps with cast-error-deferred error") {
+    withTable("T") {
+      withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN_DEFER_CAST_ERROR.key -> 
"true") {
+        sql("create table T (v variant) using parquet")
+        sql("select cast(v as int) as a, try_variant_get(v, '$.b', 'string') 
as b from T")
+          .queryExecution.optimizedPlan match {
+          case Project(projectList, l: LogicalRelation) =>
+            val output = l.output
+            val v = output(0)
+            // Strict cast should be wrapped with `UnwrapVariantCastError` 
over the sibling
+            // companion field whose `castErrorFor` metadata names the data 
field.
+            projectList(0) match {
+              case Alias(UnwrapVariantCastError(
+                  GetStructField(_, errOrd, _), GetStructField(_, 0, _)), "a") 
=>
+                assert(errOrd == 2, s"Expected companion ordinal 2, got 
$errOrd")
+              case other => fail(s"Unexpected projection 0: $other")
+            }
+            // try_variant_get is non-strict and should NOT be wrapped.
+            projectList(1) match {
+              case Alias(GetStructField(_, 1, _), "b") =>
+              case other => fail(s"Unexpected projection 1: $other")
+            }
+            val expected = StructType(Array(
+              field(0, IntegerType, "$", failOnError = true),
+              field(1, StringType, "$.b", failOnError = false),
+              StructField("2", StringType,
+                metadata = VariantMetadata.castErrorCompanionMetadata("0"))
+            ))
+            assert(v.dataType == expected, s"Got ${v.dataType}")
+          case other => fail(s"Unexpected plan: $other")
+        }
+      }
+    }
+  }
+
+  test(s"Cast-error companion is skipped for full-variant access") {
+    withTable("T") {
+      withSQLConf(
+        SQLConf.PUSH_VARIANT_INTO_SCAN_DEFER_CAST_ERROR.key -> "true") {
+        sql("create table T (v variant) using parquet")
+        // Selecting `v` alone produces only the full-variant request. 
cast-to-variant never
+        // fails, so no cast-error companion should be emitted.
+        sql("select v from T").queryExecution.optimizedPlan match {
+          case Project(_, l: LogicalRelation) =>
+            val v = l.output(0)
+            val expected = StructType(Array(
+              field(0, VariantType, "$", timeZone = "UTC")
+            ))
+            assert(v.dataType == expected, s"Got ${v.dataType}")
+          case other => fail(s"Unexpected plan: $other")
+        }
+      }
+    }
+  }
+
+  test(s"Reader defers strict-cast errors when cast-error companion is 
present") {
+    // Row 0: number 1 (LONG in variant) -> cast(v as int) succeeds.
+    // Row 1: string -> cast(v as int) would raise INVALID_VARIANT_CAST. With 
the deferral, the
+    //                  surrounding `if(schema_of_variant(v) = 'BIGINT', 
cast(v as int), null)`
+    //                  short-circuits to null before the error is observed.
+    withVariantParquetData("v variant",
+        "(parse_json('1'))",
+        "(parse_json('\"hello\"'))") {
+      val query =
+        "select if(schema_of_variant(v) = 'BIGINT', cast(v as int), null) as a 
from T"
+
+      // Without the deferral, the strict cast pushed into the scan raises at 
the failing row
+      // even though the `if` would have filtered it out.
+      withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN_DEFER_CAST_ERROR.key -> 
"false") {
+        val ex = intercept[Exception](sql(query).collect())
+        // The cast failure may surface directly or be wrapped in a task 
failure.
+        def hasCastCondition(t: Throwable): Boolean = t match {
+          case null => false
+          case s: org.apache.spark.SparkThrowable if s.getCondition == 
"INVALID_VARIANT_CAST" =>
+            true
+          case _ => hasCastCondition(t.getCause)
+        }
+        assert(hasCastCondition(ex), s"Expected INVALID_VARIANT_CAST, got $ex")
+      }
+
+      // With the deferral, the strict cast emits a cast-error companion and 
the `if`
+      // short-circuits before the failing row is consumed.
+      withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN_DEFER_CAST_ERROR.key -> 
"true") {
+        val rows = sql(query).collect()
+        val values = rows.map(r => if (r.isNullAt(0)) null else 
r.getInt(0).asInstanceOf[Any])
+          .toSet
+        assert(values == Set(1, null), s"Got ${values.mkString(",")}")
+      }
+    }
+  }
+
+  test(s"Reader defers strict-cast errors for struct target") {
+    // Row 0: object with int field -> cast(v as struct<x int>) succeeds.
+    // Row 1: scalar -> cast(v as struct<x int>) would raise (wrong kind).
+    withVariantParquetData("v variant",
+        "(parse_json('{\"x\": 1}'))",
+        "(parse_json('\"hello\"'))") {
+      val query =
+        "select if(schema_of_variant(v) like 'OBJECT<%>', cast(v as struct<x: 
int>), null) as a " +
+          "from T"
+      withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN_DEFER_CAST_ERROR.key -> 
"true") {
+        val rows = sql(query).collect()
+        val xs = rows.map { r =>
+          if (r.isNullAt(0)) null else 
r.getStruct(0).getInt(0).asInstanceOf[Any]
+        }.toSet
+        assert(xs == Set(1, null), s"Got ${xs.mkString(",")}")
+      }
+    }
+  }
+
+  test(s"Reader defers strict-cast errors for array target") {
+    // Row 0: array of ints -> cast(v as array<int>) succeeds.
+    // Row 1: scalar -> cast(v as array<int>) wrong-kind failure.
+    withVariantParquetData("v variant",
+        "(parse_json('[1, 2, 3]'))",
+        "(parse_json('\"hello\"'))") {
+      val query =
+        "select if(schema_of_variant(v) like 'ARRAY<%>', cast(v as 
array<int>), null) as a " +
+          "from T"
+      withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN_DEFER_CAST_ERROR.key -> 
"true") {
+        val rows = sql(query).collect()
+        val arrs = rows.map { r =>
+          if (r.isNullAt(0)) null else r.getList[Int](0).toArray.toSeq
+        }.toSet
+        assert(arrs == Set(Seq(1, 2, 3), null), s"Got ${arrs.mkString(",")}")
+      }
+    }
+  }
+
+  test(s"Reader surfaces deferred error for array target with inner-element 
failure") {
+    // Row 0: heterogeneous array; cast(v as array<int>) fails on the inner 
string element.
+    // With deferred errors enabled, the failure must surface when the row is 
consumed by the
+    // outer expression -- i.e., the element-level companion buffer was 
correctly aggregated to
+    // the outer row.
+    withVariantParquetData("v variant",
+        "(parse_json('[1, \"abc\"]'))") {
+      withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN_DEFER_CAST_ERROR.key -> 
"true") {
+        val ex = intercept[Exception](sql("select cast(v as array<int>) from 
T").collect())
+        def hasCastCondition(t: Throwable): Boolean = t match {
+          case null => false
+          case s: org.apache.spark.SparkThrowable if s.getCondition == 
"INVALID_VARIANT_CAST" =>
+            true
+          case _ => hasCastCondition(t.getCause)
+        }
+        assert(hasCastCondition(ex), s"Expected INVALID_VARIANT_CAST, got $ex")
+      }
+    }
+  }
+
+  test(s"Reader surfaces deferred error for struct target with field cast 
failure") {
+    // Force the writer to shred `x` as int. The inner string `"abc"` lands in 
the unshredded
+    // `value` part, and `cast(v as struct<x: int>)` reads the int via the 
shredded path, which
+    // exercises `SparkShreddingUtils.getFieldsToExtract` / 
`assembleVariantStruct` with the new
+    // companion-field pairing.
+    withSQLConf(
+      SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> "true",
+      SQLConf.VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST.key -> "x int") {
+      withVariantParquetData("v variant",
+          "(parse_json('{\"x\": \"abc\"}'))") {
+        withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN_DEFER_CAST_ERROR.key -> 
"true") {
+          val ex =
+            intercept[Exception](sql("select cast(v as struct<x: int>) from 
T").collect())
+          def hasCastCondition(t: Throwable): Boolean = t match {
+            case null => false
+            case s: org.apache.spark.SparkThrowable
+                if s.getCondition == "INVALID_VARIANT_CAST" => true
+            case _ => hasCastCondition(t.getCause)
+          }
+          assert(hasCastCondition(ex), s"Expected INVALID_VARIANT_CAST, got 
$ex")
+        }
+      }
+    }
+  }
+
+  // Returns true iff `t` or any of its causes is an INVALID_VARIANT_CAST 
error.
+  private def hasCastCondition(t: Throwable): Boolean = t match {

Review Comment:
   `hasCastCondition` is also defined inline in three earlier tests: the 
`if(...)` scalar-cast test, the array inner-element failure test, and the 
struct field-cast failure test. Could those reuse this private helper and 
drop the duplicated copies?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to