the-other-tim-brown commented on code in PR #17904:
URL: https://github.com/apache/hudi/pull/17904#discussion_r2769922617


##########
hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/execution/datasources/SparkSchemaTransformUtils.scala:
##########
@@ -0,0 +1,426 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.HoodieSchemaUtils
+import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, Attribute, 
AttributeReference, Cast, CreateNamedStruct, CreateStruct, Expression, 
GetStructField, LambdaFunction, Literal, MapEntries, MapFromEntries, 
NamedLambdaVariable, UnsafeProjection}
+import org.apache.spark.sql.types.{ArrayType, DataType, DateType, DecimalType, 
DoubleType, FloatType, IntegerType, LongType, MapType, StringType, StructField, 
StructType, TimestampNTZType}
+
+/**
+ * Format-agnostic utilities for Spark schema transformations including NULL 
padding
+ * and recursive type casting with workarounds for unsafe conversions.
+ *
+ * These utilities are used by file format readers that need to:
+ * - Pad missing columns with NULL literals (required for Lance)
+ * - Handle nested struct/array/map type conversions
+ * - Work around Spark unsafe cast issues (float->double, numeric->decimal)
+ *
+ * Note: The following functions were originally part of 
HoodieParquetFileFormatHelper
+ * and have been moved here to allow reuse across multiple file formats:
+ * - buildImplicitSchemaChangeInfo
+ * - isDataTypeEqual
+ * - generateUnsafeProjection
+ * - hasUnsupportedConversion
+ * - recursivelyCastExpressions
+ */
+object SparkSchemaTransformUtils {
+
+  /**
+   * Generate UnsafeProjection for type casting with special handling for 
unsupported conversions.
+   *
+   * @param fullSchema Complete schema including data and partition columns
+   * @param timeZoneId Session timezone for timestamp conversions
+   * @param typeChangeInfos Map of field index to (targetType, readerType) for 
fields needing casting
+   * @param requiredSchema Schema requested by the query (data columns only)
+   * @param partitionSchema Schema of partition columns
+   * @param schemaUtils Spark adapter schema utilities
+   * @return UnsafeProjection that applies type casting to rows
+   */
+  def generateUnsafeProjection(fullSchema: Seq[Attribute],
+                               timeZoneId: Option[String],
+                               typeChangeInfos: java.util.Map[Integer, 
org.apache.hudi.common.util.collection.Pair[DataType, DataType]],
+                               requiredSchema: StructType,
+                               partitionSchema: StructType,
+                               schemaUtils: HoodieSchemaUtils): 
UnsafeProjection = {
+    if (typeChangeInfos.isEmpty) {
+      GenerateUnsafeProjection.generate(fullSchema, fullSchema)
+    } else {
+      // find type changed.
+      val newSchema = new StructType(requiredSchema.fields.zipWithIndex.map { 
case (f, i) =>
+        if (typeChangeInfos.containsKey(i)) {
+          StructField(f.name, typeChangeInfos.get(i).getRight, f.nullable, 
f.metadata)
+        } else f
+      })
+      val newFullSchema = schemaUtils.toAttributes(newSchema) ++ 
schemaUtils.toAttributes(partitionSchema)
+      val castSchema = newFullSchema.zipWithIndex.map { case (attr, i) =>
+        if (typeChangeInfos.containsKey(i)) {
+          val srcType = typeChangeInfos.get(i).getRight
+          val dstType = typeChangeInfos.get(i).getLeft
+          SparkSchemaTransformUtils.recursivelyCastExpressions(
+            attr, srcType, dstType, timeZoneId
+          )
+        } else attr
+      }
+      GenerateUnsafeProjection.generate(castSchema, newFullSchema)
+    }
+  }
+
+  /**
+   * Generate UnsafeProjection that pads missing columns with NULL literals.
+   *
+   * @param inputSchema Schema from file (fields actually present)
+   * @param targetSchema Target output schema (may have more fields than file)
+   * @return UnsafeProjection that pads missing columns with NULLs
+   */
+  def generateNullPaddingProjection(
+      inputSchema: StructType,
+      targetSchema: StructType
+  ): UnsafeProjection = {
+    val inputAttributes = inputSchema.fields.map(f =>
+      AttributeReference(f.name, f.dataType, f.nullable)())
+    val inputFieldMap = inputAttributes.map(a => a.name -> a).toMap
+
+    // Build expressions for all target fields, padding missing columns with 
NULL
+    val expressions = targetSchema.fields.map { field =>
+      inputFieldMap.get(field.name) match {
+        case Some(attr) =>
+          // Field exists in input - check if nested padding needed
+          if (needsNestedPadding(attr.dataType, field.dataType)) {
+            recursivelyPadExpression(attr, attr.dataType, field.dataType)
+          } else {
+            attr
+          }
+        case None =>
+          // Field missing from input, use NULL literal for padding
+          Literal(null, field.dataType)
+      }
+    }
+
+    GenerateUnsafeProjection.generate(expressions, inputAttributes)
+  }
+
  /**
   * Recursively pad nested struct/array/map fields with NULLs.
   *
   * Walks `srcType` and `dstType` in parallel, rebuilding the value so that any field
   * present in the destination but missing from the source materializes as a typed
   * NULL literal. Non-container types (and containers needing no padding) fall through
   * to the source expression unchanged.
   *
   * NOTE(review): in the struct branch, `CreateNamedStruct` always yields a non-null
   * struct — if `expr` itself evaluates to NULL, the result is presumably a struct of
   * NULL fields rather than a NULL struct. Confirm this is acceptable for nullable
   * nested structs (otherwise an `If(IsNull(expr), ...)` guard may be needed).
   *
   * @param expr Source expression
   * @param srcType Source data type
   * @param dstType Destination data type (may have additional nested fields)
   * @return Expression with NULL padding for missing nested fields
   */
  private def recursivelyPadExpression(
      expr: Expression,
      srcType: DataType,
      dstType: DataType
  ): Expression = (srcType, dstType) match {
    case (s: StructType, d: StructType) =>
      // Index source fields by name so destination fields can locate their ordinal.
      val srcFieldMap = s.fields.zipWithIndex.map { case (f, i) => f.name -> (f, i) }.toMap
      val structFields = d.fields.map { dstField =>
        srcFieldMap.get(dstField.name) match {
          case Some((srcField, srcIndex)) =>
            // Field exists in the source: extract it and recurse for deeper padding.
            val child = GetStructField(expr, srcIndex, Some(dstField.name))
            recursivelyPadExpression(child, srcField.dataType, dstField.dataType)
          case None =>
            // Field only exists in the destination: pad with a typed NULL.
            Literal(null, dstField.dataType)
        }
      }
      // CreateNamedStruct takes alternating (name, value) children.
      CreateNamedStruct(d.fields.zip(structFields).flatMap {
        case (f, c) => Seq(Literal(f.name), c)
      })

    case (ArrayType(sElementType, containsNull), ArrayType(dElementType, _))
        if needsNestedPadding(sElementType, dElementType) =>
      // Pad each element via a higher-order transform over the array.
      val lambdaVar = NamedLambdaVariable("element", sElementType, containsNull)
      val body = recursivelyPadExpression(lambdaVar, sElementType, dElementType)
      val func = LambdaFunction(body, Seq(lambdaVar))
      ArrayTransform(expr, func)

    case (MapType(sKeyType, sValType, vnull), MapType(dKeyType, dValType, _))
        if needsNestedPadding(sKeyType, dKeyType) || needsNestedPadding(sValType, dValType) =>
      // Explode the map into (key, value) entry structs, pad key and value
      // independently, then reassemble the map from the transformed entries.
      val kv = NamedLambdaVariable("kv", new StructType()
        .add("key", sKeyType, nullable = false)
        .add("value", sValType, nullable = vnull), nullable = false)
      val newKey = recursivelyPadExpression(GetStructField(kv, 0), sKeyType, dKeyType)
      val newVal = recursivelyPadExpression(GetStructField(kv, 1), sValType, dValType)
      val entry = CreateStruct(Seq(newKey, newVal))
      val func = LambdaFunction(entry, Seq(kv))
      val transformed = ArrayTransform(MapEntries(expr), func)
      MapFromEntries(transformed)

    case _ =>
      // No padding needed, return expression as-is
      expr
  }
+
+  /**
+   * Recursively cast expressions with special handling for unsupported 
conversions.
+   *
+   * @param expr Source expression to cast
+   * @param srcType Source data type
+   * @param dstType Destination data type
+   * @param timeZoneId Session timezone for timestamp conversions
+   * @return Casted expression with workarounds for unsafe conversions
+   */
+  def recursivelyCastExpressions(

Review Comment:
   Can some of these methods be private?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to