This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new b9ac78be feat: Add Spark-compatible implementation of SchemaAdapterFactory (#1169)
b9ac78be is described below

commit b9ac78beffe8b71388ceb8d7e842fd6a07395829
Author: Andy Grove <[email protected]>
AuthorDate: Tue Dec 17 16:35:54 2024 -0700

    feat: Add Spark-compatible implementation of SchemaAdapterFactory (#1169)
    
    * Add Spark-compatible SchemaAdapterFactory implementation
    
    * remove prototype code
    
    * fix
    
    * refactor
    
    * implement more cast logic
    
    * implement more cast logic
    
    * add basic test
    
    * improve test
    
    * cleanup
    
    * fmt
    
    * add support for casting unsigned int to signed int
    
    * clippy
    
    * address feedback
    
    * fix test
---
 native/Cargo.lock                                  |  67 +++-
 native/Cargo.toml                                  |   2 +-
 native/core/src/parquet/util/test_common/mod.rs    |   3 +-
 native/spark-expr/Cargo.toml                       |   4 +-
 native/spark-expr/src/cast.rs                      | 353 +++++++++++++++----
 native/spark-expr/src/lib.rs                       |   5 +
 native/spark-expr/src/schema_adapter.rs            | 376 +++++++++++++++++++++
 .../src}/test_common/file_util.rs                  |   0
 .../util => spark-expr/src}/test_common/mod.rs     |   7 -
 9 files changed, 741 insertions(+), 76 deletions(-)
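
The new SparkSchemaAdapterFactory is plugged into a scan through ParquetExec's
builder API, so that mismatches between the Parquet file schema and the required
(table) schema are resolved with Spark-compatible casts rather than plain Arrow
casts. A minimal sketch of that wiring, modelled on the roundtrip test in
schema_adapter.rs below (the helper function name, the path argument, and the
crate-root re-exports of EvalMode and SparkCastOptions are assumptions, not part
of this commit):

    use std::sync::Arc;

    use arrow_schema::SchemaRef;
    use datafusion::datasource::listing::PartitionedFile;
    use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
    use datafusion::execution::object_store::ObjectStoreUrl;
    use datafusion_comet_spark_expr::{EvalMode, SparkCastOptions, SparkSchemaAdapterFactory};

    // Build a ParquetExec that adapts the file schema to `required_schema`
    // using Spark-compatible casts (hypothetical helper, not in this commit).
    fn spark_compatible_scan(
        required_schema: SchemaRef,
        path: &str,
    ) -> datafusion_common::Result<ParquetExec> {
        let file_scan_config =
            FileScanConfig::new(ObjectStoreUrl::local_filesystem(), required_schema)
                .with_file_groups(vec![vec![PartitionedFile::from_path(path.to_string())?]]);

        // Legacy eval mode, UTC timezone, no incompatible casts; additionally
        // allow the unsigned-to-signed integer casts added in this commit.
        let mut spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
        spark_cast_options.allow_cast_unsigned_ints = true;

        Ok(ParquetExec::builder(file_scan_config)
            .with_schema_adapter_factory(Arc::new(SparkSchemaAdapterFactory::new(
                spark_cast_options,
            )))
            .build())
    }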

diff --git a/native/Cargo.lock b/native/Cargo.lock
index 7966bb80..538c40ee 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -436,7 +436,18 @@ checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391"
 dependencies = [
  "alloc-no-stdlib",
  "alloc-stdlib",
- "brotli-decompressor",
+ "brotli-decompressor 2.5.1",
+]
+
+[[package]]
+name = "brotli"
+version = "7.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd"
+dependencies = [
+ "alloc-no-stdlib",
+ "alloc-stdlib",
+ "brotli-decompressor 4.0.1",
 ]
 
 [[package]]
@@ -449,6 +460,16 @@ dependencies = [
  "alloc-stdlib",
 ]
 
+[[package]]
+name = "brotli-decompressor"
+version = "4.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362"
+dependencies = [
+ "alloc-no-stdlib",
+ "alloc-stdlib",
+]
+
 [[package]]
 name = "bumpalo"
 version = "3.16.0"
@@ -842,6 +863,7 @@ dependencies = [
  "num_cpus",
  "object_store",
  "parking_lot",
+ "parquet",
  "paste",
  "pin-project-lite",
  "rand",
@@ -878,7 +900,7 @@ dependencies = [
  "arrow-schema",
  "assertables",
  "async-trait",
- "brotli",
+ "brotli 3.5.0",
  "bytes",
  "crc32fast",
  "criterion",
@@ -914,7 +936,7 @@ dependencies = [
  "tempfile",
  "thiserror",
  "tokio",
- "zstd",
+ "zstd 0.11.2+zstd.1.5.2",
 ]
 
 [[package]]
@@ -943,6 +965,7 @@ dependencies = [
  "datafusion-physical-expr",
  "futures",
  "num",
+ "parquet",
  "rand",
  "regex",
  "thiserror",
@@ -969,6 +992,7 @@ dependencies = [
  "libc",
  "num_cpus",
  "object_store",
+ "parquet",
  "paste",
  "sqlparser",
  "tokio",
@@ -2350,16 +2374,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "dea02606ba6f5e856561d8d507dba8bac060aefca2a6c0f1aa1d361fed91ff3e"
 dependencies = [
  "ahash",
+ "arrow-array",
+ "arrow-buffer",
+ "arrow-cast",
+ "arrow-data",
+ "arrow-ipc",
+ "arrow-schema",
+ "arrow-select",
+ "base64",
+ "brotli 7.0.0",
  "bytes",
  "chrono",
+ "flate2",
+ "futures",
  "half",
  "hashbrown 0.14.5",
+ "lz4_flex",
  "num",
  "num-bigint",
+ "object_store",
  "paste",
  "seq-macro",
+ "snap",
  "thrift",
+ "tokio",
  "twox-hash 1.6.3",
+ "zstd 0.13.2",
+ "zstd-sys",
 ]
 
 [[package]]
@@ -3652,7 +3693,16 @@ version = "0.11.2+zstd.1.5.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
 dependencies = [
- "zstd-safe",
+ "zstd-safe 5.0.2+zstd.1.5.2",
+]
+
+[[package]]
+name = "zstd"
+version = "0.13.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9"
+dependencies = [
+ "zstd-safe 7.2.1",
 ]
 
 [[package]]
@@ -3665,6 +3715,15 @@ dependencies = [
  "zstd-sys",
 ]
 
+[[package]]
+name = "zstd-safe"
+version = "7.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059"
+dependencies = [
+ "zstd-sys",
+]
+
 [[package]]
 name = "zstd-sys"
 version = "2.0.13+zstd.1.5.6"
diff --git a/native/Cargo.toml b/native/Cargo.toml
index 4ac85479..bd46cf0c 100644
--- a/native/Cargo.toml
+++ b/native/Cargo.toml
@@ -39,8 +39,8 @@ arrow-buffer = { version = "53.2.0" }
 arrow-data = { version = "53.2.0" }
 arrow-schema = { version = "53.2.0" }
 parquet = { version = "53.2.0", default-features = false, features = ["experimental"] }
-datafusion-common = { version = "43.0.0" }
 datafusion = { version = "43.0.0", default-features = false, features = ["unicode_expressions", "crypto_expressions"] }
+datafusion-common = { version = "43.0.0" }
 datafusion-functions = { version = "43.0.0", features = ["crypto_expressions"] }
 datafusion-functions-nested = { version = "43.0.0", default-features = false }
 datafusion-expr = { version = "43.0.0", default-features = false }
diff --git a/native/core/src/parquet/util/test_common/mod.rs b/native/core/src/parquet/util/test_common/mod.rs
index e46d7322..d9254460 100644
--- a/native/core/src/parquet/util/test_common/mod.rs
+++ b/native/core/src/parquet/util/test_common/mod.rs
@@ -15,10 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-pub mod file_util;
 pub mod page_util;
 pub mod rand_gen;
 
 pub use self::rand_gen::{random_bools, random_bytes, random_numbers, random_numbers_range};
 
-pub use self::file_util::{get_temp_file, get_temp_filename};
+pub use datafusion_comet_spark_expr::test_common::file_util::{get_temp_file, get_temp_filename};
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index d0bc2fd9..27367d83 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -33,7 +33,7 @@ arrow-buffer = { workspace = true }
 arrow-data = { workspace = true }
 arrow-schema = { workspace = true }
 chrono = { workspace = true }
-datafusion = { workspace = true }
+datafusion = { workspace = true, features = ["parquet"] }
 datafusion-common = { workspace = true }
 datafusion-expr = { workspace = true }
 datafusion-physical-expr = { workspace = true }
@@ -43,9 +43,11 @@ regex = { workspace = true }
 thiserror = { workspace = true }
 futures = { workspace = true }
 twox-hash = "2.0.0"
+rand = { workspace = true }
 
 [dev-dependencies]
 arrow-data = {workspace = true}
+parquet = { workspace = true, features = ["arrow"] }
 criterion = "0.5.1"
 rand = { workspace = true}
 tokio = { version = "1", features = ["rt-multi-thread"] }
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index f62d0220..d96bcbbd 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::timezone;
+use crate::utils::array_with_timezone;
+use crate::{EvalMode, SparkError, SparkResult};
 use arrow::{
     array::{
         cast::AsArray,
@@ -35,11 +38,18 @@ use arrow::{
 use arrow_array::builder::StringBuilder;
 use arrow_array::{DictionaryArray, StringArray, StructArray};
 use arrow_schema::{DataType, Field, Schema};
+use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike};
+use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
 use datafusion_common::{
     cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
 };
 use datafusion_expr::ColumnarValue;
 use datafusion_physical_expr::PhysicalExpr;
+use num::{
+    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
+    ToPrimitive,
+};
+use regex::Regex;
 use std::str::FromStr;
 use std::{
     any::Any,
@@ -49,19 +59,6 @@ use std::{
     sync::Arc,
 };
 
-use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike};
-use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
-use num::{
-    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
-    ToPrimitive,
-};
-use regex::Regex;
-
-use crate::timezone;
-use crate::utils::array_with_timezone;
-
-use crate::{EvalMode, SparkError, SparkResult};
-
 static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
 
 const MICROS_PER_SECOND: i64 = 1000000;
@@ -141,6 +138,240 @@ pub struct Cast {
     pub cast_options: SparkCastOptions,
 }
 
+/// Determine if Comet supports a cast, taking options such as EvalMode and Timezone into account.
+pub fn cast_supported(
+    from_type: &DataType,
+    to_type: &DataType,
+    options: &SparkCastOptions,
+) -> bool {
+    use DataType::*;
+
+    let from_type = if let Dictionary(_, dt) = from_type {
+        dt
+    } else {
+        from_type
+    };
+
+    let to_type = if let Dictionary(_, dt) = to_type {
+        dt
+    } else {
+        to_type
+    };
+
+    if from_type == to_type {
+        return true;
+    }
+
+    match (from_type, to_type) {
+        (Boolean, _) => can_cast_from_boolean(to_type, options),
+        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
+            if options.allow_cast_unsigned_ints =>
+        {
+            true
+        }
+        (Int8, _) => can_cast_from_byte(to_type, options),
+        (Int16, _) => can_cast_from_short(to_type, options),
+        (Int32, _) => can_cast_from_int(to_type, options),
+        (Int64, _) => can_cast_from_long(to_type, options),
+        (Float32, _) => can_cast_from_float(to_type, options),
+        (Float64, _) => can_cast_from_double(to_type, options),
+        (Decimal128(p, s), _) => can_cast_from_decimal(p, s, to_type, options),
+        (Timestamp(_, None), _) => can_cast_from_timestamp_ntz(to_type, options),
+        (Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options),
+        (Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options),
+        (_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options),
+        (Struct(from_fields), Struct(to_fields)) => from_fields
+            .iter()
+            .zip(to_fields.iter())
+            .all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)),
+        _ => false,
+    }
+}
+
+fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool {
+    use DataType::*;
+    match to_type {
+        Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true,
+        Float32 | Float64 => {
+            // https://github.com/apache/datafusion-comet/issues/326
+            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
+            // Does not support ANSI mode.
+            options.allow_incompat
+        }
+        Decimal128(_, _) => {
+            // https://github.com/apache/datafusion-comet/issues/325
+            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
+            // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits
+
+            options.allow_incompat
+        }
+        Date32 | Date64 => {
+            // https://github.com/apache/datafusion-comet/issues/327
+            // Only supports years between 262143 BC and 262142 AD
+            options.allow_incompat
+        }
+        Timestamp(_, _) if options.eval_mode == EvalMode::Ansi => {
+            // ANSI mode not supported
+            false
+        }
+        Timestamp(_, Some(tz)) if tz.as_ref() != "UTC" => {
+            // Cast will use UTC instead of $timeZoneId
+            options.allow_incompat
+        }
+        Timestamp(_, _) => {
+            // https://github.com/apache/datafusion-comet/issues/328
+            // Not all valid formats are supported
+            options.allow_incompat
+        }
+        _ => false,
+    }
+}
+
+fn can_cast_to_string(from_type: &DataType, options: &SparkCastOptions) -> bool {
+    use DataType::*;
+    match from_type {
+        Boolean | Int8 | Int16 | Int32 | Int64 | Date32 | Date64 | Timestamp(_, _) => true,
+        Float32 | Float64 => {
+            // There can be differences in precision.
+            // For example, the input "1.4E-45" will produce 1.0E-45 instead of 1.4E-45
+            true
+        }
+        Decimal128(_, _) => {
+            // https://github.com/apache/datafusion-comet/issues/1068
+            // There can be formatting differences in some cases due to Spark using
+            // scientific notation where Comet does not
+            true
+        }
+        Binary => {
+            // https://github.com/apache/datafusion-comet/issues/377
+            // Only works for binary data representing valid UTF-8 strings
+            options.allow_incompat
+        }
+        Struct(fields) => fields
+            .iter()
+            .all(|f| can_cast_to_string(f.data_type(), options)),
+        _ => false,
+    }
+}
+
+fn can_cast_from_timestamp_ntz(to_type: &DataType, options: &SparkCastOptions) -> bool {
+    use DataType::*;
+    match to_type {
+        Timestamp(_, _) | Date32 | Date64 | Utf8 => {
+            // incompatible
+            options.allow_incompat
+        }
+        _ => {
+            // unsupported
+            false
+        }
+    }
+}
+
+fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> bool {
+    use DataType::*;
+    match to_type {
+        Boolean | Int8 | Int16 => {
+            // https://github.com/apache/datafusion-comet/issues/352
+            // this seems like an edge case that isn't important for us to support
+            false
+        }
+        Int64 => {
+            // https://github.com/apache/datafusion-comet/issues/352
+            true
+        }
+        Date32 | Date64 | Utf8 | Decimal128(_, _) => true,
+        _ => {
+            // unsupported
+            false
+        }
+    }
+}
+
+fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
+    use DataType::*;
+    matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
+}
+
+fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
+    use DataType::*;
+    matches!(
+        to_type,
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+    )
+}
+
+fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
+    use DataType::*;
+    matches!(
+        to_type,
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+    )
+}
+
+fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
+    use DataType::*;
+    match to_type {
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
+        Decimal128(_, _) => {
+            // incompatible: no overflow check
+            options.allow_incompat
+        }
+        _ => false,
+    }
+}
+
+fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
+    use DataType::*;
+    match to_type {
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
+        Decimal128(_, _) => {
+            // incompatible: no overflow check
+            options.allow_incompat
+        }
+        _ => false,
+    }
+}
+
+fn can_cast_from_float(to_type: &DataType, _: &SparkCastOptions) -> bool {
+    use DataType::*;
+    matches!(
+        to_type,
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float64 | Decimal128(_, _)
+    )
+}
+
+fn can_cast_from_double(to_type: &DataType, _: &SparkCastOptions) -> bool {
+    use DataType::*;
+    matches!(
+        to_type,
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Decimal128(_, _)
+    )
+}
+
+fn can_cast_from_decimal(
+    p1: &u8,
+    _s1: &i8,
+    to_type: &DataType,
+    options: &SparkCastOptions,
+) -> bool {
+    use DataType::*;
+    match to_type {
+        Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
+        Decimal128(p2, _) => {
+            if p2 < p1 {
+                // https://github.com/apache/datafusion/issues/13492
+                // Incompatible(Some("Casting to smaller precision is not supported"))
+                options.allow_incompat
+            } else {
+                true
+            }
+        }
+        _ => false,
+    }
+}
+
 macro_rules! cast_utf8_to_int {
     ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
         let len = $array.len();
@@ -560,6 +791,8 @@ pub struct SparkCastOptions {
     pub timezone: String,
     /// Allow casts that are supported but not guaranteed to be 100% compatible
     pub allow_incompat: bool,
+    /// Support casting unsigned ints to signed ints (used by Parquet SchemaAdapter)
+    pub allow_cast_unsigned_ints: bool,
 }
 
 impl SparkCastOptions {
@@ -568,6 +801,7 @@ impl SparkCastOptions {
             eval_mode,
             timezone: timezone.to_string(),
             allow_incompat,
+            allow_cast_unsigned_ints: false,
         }
     }
 
@@ -576,6 +810,7 @@ impl SparkCastOptions {
             eval_mode,
             timezone: "".to_string(),
             allow_incompat,
+            allow_cast_unsigned_ints: false,
         }
     }
 }
@@ -611,14 +846,14 @@ fn cast_array(
     to_type: &DataType,
     cast_options: &SparkCastOptions,
 ) -> DataFusionResult<ArrayRef> {
+    use DataType::*;
     let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
     let from_type = array.data_type().clone();
 
     let array = match &from_type {
-        DataType::Dictionary(key_type, value_type)
-            if key_type.as_ref() == &DataType::Int32
-                && (value_type.as_ref() == &DataType::Utf8
-                    || value_type.as_ref() == &DataType::LargeUtf8) =>
+        Dictionary(key_type, value_type)
+            if key_type.as_ref() == &Int32
+                && (value_type.as_ref() == &Utf8 || value_type.as_ref() == &LargeUtf8) =>
         {
             let dict_array = array
                 .as_any()
@@ -631,7 +866,7 @@ fn cast_array(
             );
 
             let casted_result = match to_type {
-                DataType::Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
+                Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
                 _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
             };
             return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
@@ -642,70 +877,66 @@ fn cast_array(
     let eval_mode = cast_options.eval_mode;
 
     let cast_result = match (from_type, to_type) {
-        (DataType::Utf8, DataType::Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
-        (DataType::LargeUtf8, DataType::Boolean) => {
-            spark_cast_utf8_to_boolean::<i64>(&array, eval_mode)
-        }
-        (DataType::Utf8, DataType::Timestamp(_, _)) => {
+        (Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
+        (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
+        (Utf8, Timestamp(_, _)) => {
             cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
         }
-        (DataType::Utf8, DataType::Date32) => cast_string_to_date(&array, to_type, eval_mode),
-        (DataType::Int64, DataType::Int32)
-        | (DataType::Int64, DataType::Int16)
-        | (DataType::Int64, DataType::Int8)
-        | (DataType::Int32, DataType::Int16)
-        | (DataType::Int32, DataType::Int8)
-        | (DataType::Int16, DataType::Int8)
+        (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+        (Int64, Int32)
+        | (Int64, Int16)
+        | (Int64, Int8)
+        | (Int32, Int16)
+        | (Int32, Int8)
+        | (Int16, Int8)
             if eval_mode != EvalMode::Try =>
         {
             spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
         }
-        (DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64) => {
+        (Utf8, Int8 | Int16 | Int32 | Int64) => {
             cast_string_to_int::<i32>(to_type, &array, eval_mode)
         }
-        (
-            DataType::LargeUtf8,
-            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
-        ) => cast_string_to_int::<i64>(to_type, &array, eval_mode),
-        (DataType::Float64, DataType::Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
-        (DataType::Float64, DataType::LargeUtf8) => {
-            spark_cast_float64_to_utf8::<i64>(&array, eval_mode)
-        }
-        (DataType::Float32, DataType::Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
-        (DataType::Float32, DataType::LargeUtf8) => {
-            spark_cast_float32_to_utf8::<i64>(&array, eval_mode)
-        }
-        (DataType::Float32, DataType::Decimal128(precision, scale)) => {
+        (LargeUtf8, Int8 | Int16 | Int32 | Int64) => {
+            cast_string_to_int::<i64>(to_type, &array, eval_mode)
+        }
+        (Float64, Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
+        (Float64, LargeUtf8) => spark_cast_float64_to_utf8::<i64>(&array, eval_mode),
+        (Float32, Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
+        (Float32, LargeUtf8) => spark_cast_float32_to_utf8::<i64>(&array, eval_mode),
+        (Float32, Decimal128(precision, scale)) => {
             cast_float32_to_decimal128(&array, *precision, *scale, eval_mode)
         }
-        (DataType::Float64, DataType::Decimal128(precision, scale)) => {
+        (Float64, Decimal128(precision, scale)) => {
             cast_float64_to_decimal128(&array, *precision, *scale, eval_mode)
         }
-        (DataType::Float32, DataType::Int8)
-        | (DataType::Float32, DataType::Int16)
-        | (DataType::Float32, DataType::Int32)
-        | (DataType::Float32, DataType::Int64)
-        | (DataType::Float64, DataType::Int8)
-        | (DataType::Float64, DataType::Int16)
-        | (DataType::Float64, DataType::Int32)
-        | (DataType::Float64, DataType::Int64)
-        | (DataType::Decimal128(_, _), DataType::Int8)
-        | (DataType::Decimal128(_, _), DataType::Int16)
-        | (DataType::Decimal128(_, _), DataType::Int32)
-        | (DataType::Decimal128(_, _), DataType::Int64)
+        (Float32, Int8)
+        | (Float32, Int16)
+        | (Float32, Int32)
+        | (Float32, Int64)
+        | (Float64, Int8)
+        | (Float64, Int16)
+        | (Float64, Int32)
+        | (Float64, Int64)
+        | (Decimal128(_, _), Int8)
+        | (Decimal128(_, _), Int16)
+        | (Decimal128(_, _), Int32)
+        | (Decimal128(_, _), Int64)
             if eval_mode != EvalMode::Try =>
         {
             spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
         }
-        (DataType::Struct(_), DataType::Utf8) => {
-            Ok(casts_struct_to_string(array.as_struct(), cast_options)?)
-        }
-        (DataType::Struct(_), DataType::Struct(_)) => Ok(cast_struct_to_struct(
+        (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
+        (Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
             array.as_struct(),
             from_type,
             to_type,
             cast_options,
         )?),
+        (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
+            if cast_options.allow_cast_unsigned_ints =>
+        {
+            Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
+        }
         _ if is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) => {
             // use DataFusion cast only when we know that it is compatible with Spark
             Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 8a574805..f3587310 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -41,6 +41,9 @@ mod kernels;
 mod list;
 mod regexp;
 pub mod scalar_funcs;
+mod schema_adapter;
+pub use schema_adapter::SparkSchemaAdapterFactory;
+
 pub mod spark_hash;
 mod stddev;
 pub use stddev::Stddev;
@@ -51,6 +54,8 @@ mod negative;
 pub use negative::{create_negate_expr, NegativeExpr};
 mod normalize_nan;
 mod temporal;
+
+pub mod test_common;
 pub mod timezone;
 mod to_json;
 mod unbound;
diff --git a/native/spark-expr/src/schema_adapter.rs b/native/spark-expr/src/schema_adapter.rs
new file mode 100644
index 00000000..161ad6f1
--- /dev/null
+++ b/native/spark-expr/src/schema_adapter.rs
@@ -0,0 +1,376 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Custom schema adapter that uses Spark-compatible casts
+
+use crate::cast::cast_supported;
+use crate::{spark_cast, SparkCastOptions};
+use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions};
+use arrow_schema::{Schema, SchemaRef};
+use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper};
+use datafusion_common::plan_err;
+use datafusion_expr::ColumnarValue;
+use std::sync::Arc;
+
+/// An implementation of DataFusion's `SchemaAdapterFactory` that uses a Spark-compatible
+/// `cast` implementation.
+#[derive(Clone, Debug)]
+pub struct SparkSchemaAdapterFactory {
+    /// Spark cast options
+    cast_options: SparkCastOptions,
+}
+
+impl SparkSchemaAdapterFactory {
+    pub fn new(options: SparkCastOptions) -> Self {
+        Self {
+            cast_options: options,
+        }
+    }
+}
+
+impl SchemaAdapterFactory for SparkSchemaAdapterFactory {
+    /// Create a schema adapter that maps columns from a file schema to the
+    /// required (projected) schema and the table schema, applying
+    /// Spark-compatible casts where the data types differ.
+    fn create(
+        &self,
+        required_schema: SchemaRef,
+        table_schema: SchemaRef,
+    ) -> Box<dyn SchemaAdapter> {
+        Box::new(SparkSchemaAdapter {
+            required_schema,
+            table_schema,
+            cast_options: self.cast_options.clone(),
+        })
+    }
+}
+
+/// This SchemaAdapter requires both the table schema and the projected table
+/// schema. See [`SchemaMapping`] for more details.
+#[derive(Clone, Debug)]
+pub struct SparkSchemaAdapter {
+    /// The schema for the table, projected to include only the fields being output (projected) by the
+    /// associated ParquetExec
+    required_schema: SchemaRef,
+    /// The entire table schema for the table we're using this to adapt.
+    ///
+    /// This is used to evaluate any filters pushed down into the scan
+    /// which may refer to columns that are not referred to anywhere
+    /// else in the plan.
+    table_schema: SchemaRef,
+    /// Spark cast options
+    cast_options: SparkCastOptions,
+}
+
+impl SchemaAdapter for SparkSchemaAdapter {
+    /// Map a column index in the table schema to a column index in a particular
+    /// file schema
+    ///
+    /// Panics if index is not in range for the table schema
+    fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
+        let field = self.required_schema.field(index);
+        Some(file_schema.fields.find(field.name())?.0)
+    }
+
+    /// Creates a `SchemaMapping` for casting or mapping the columns from the
+    /// file schema to the table schema.
+    ///
+    /// If the provided `file_schema` contains columns of a different type to
+    /// the expected `table_schema`, the method will attempt to cast the array
+    /// data from the file schema to the table schema where possible.
+    ///
+    /// Returns a [`SchemaMapping`] that can be applied to the output batch
+    /// along with an ordered list of columns to project from the file
+    fn map_schema(
+        &self,
+        file_schema: &Schema,
+    ) -> datafusion_common::Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
+        let mut projection = Vec::with_capacity(file_schema.fields().len());
+        let mut field_mappings = vec![None; self.required_schema.fields().len()];
+
+        for (file_idx, file_field) in file_schema.fields.iter().enumerate() {
+            if let Some((table_idx, table_field)) =
+                self.required_schema.fields().find(file_field.name())
+            {
+                if cast_supported(
+                    file_field.data_type(),
+                    table_field.data_type(),
+                    &self.cast_options,
+                ) {
+                    field_mappings[table_idx] = Some(projection.len());
+                    projection.push(file_idx);
+                } else {
+                    return plan_err!(
+                        "Cannot cast file schema field {} of type {:?} to required schema field of type {:?}",
+                        file_field.name(),
+                        file_field.data_type(),
+                        table_field.data_type()
+                    );
+                }
+            }
+        }
+
+        Ok((
+            Arc::new(SchemaMapping {
+                required_schema: Arc::<Schema>::clone(&self.required_schema),
+                field_mappings,
+                table_schema: Arc::<Schema>::clone(&self.table_schema),
+                cast_options: self.cast_options.clone(),
+            }),
+            projection,
+        ))
+    }
+}
+
+// TODO SchemaMapping is mostly copied from DataFusion but calls spark_cast
+// instead of arrow cast - can we reduce the amount of code copied here and make
+// the DataFusion version more extensible?
+
+/// The SchemaMapping struct holds a mapping from the file schema to the table
+/// schema and any necessary type conversions.
+///
+/// Note, because `map_batch` and `map_partial_batch` functions have different
+/// needs, this struct holds two schemas:
+///
+/// 1. The projected **table** schema
+/// 2. The full table schema
+///
+/// [`map_batch`] is used by the ParquetOpener to produce a RecordBatch which
+/// has the projected schema, since that's the schema which is supposed to come
+/// out of the execution of this query. Thus `map_batch` uses
+/// `projected_table_schema` as it can only operate on the projected fields.
+///
+/// [`map_partial_batch`] is used to create a RecordBatch with a schema that
+/// can be used for Parquet predicate pushdown, meaning that it may contain
+/// fields which are not in the projected schema (as the fields that parquet
+/// pushdown filters operate on can be completely distinct from the fields that
+/// are projected (output) out of the ParquetExec). `map_partial_batch` thus uses
+/// `table_schema` to create the resulting RecordBatch (as it could be operating
+/// on any fields in the schema).
+///
+/// [`map_batch`]: Self::map_batch
+/// [`map_partial_batch`]: Self::map_partial_batch
+#[derive(Debug)]
+pub struct SchemaMapping {
+    /// The schema of the table. This is the expected schema after conversion
+    /// and it should match the schema of the query result.
+    required_schema: SchemaRef,
+    /// Mapping from field index in `required_schema` to index in the
+    /// projected file schema.
+    ///
+    /// They are Options instead of just plain `usize`s because the table could
+    /// have fields that don't exist in the file.
+    field_mappings: Vec<Option<usize>>,
+    /// The entire table schema, as opposed to the `required_schema` (which
+    /// only contains the columns that we are projecting out of this query).
+    /// This contains all fields in the table, regardless of if they will be
+    /// projected out or not.
+    table_schema: SchemaRef,
+    /// Spark cast options
+    cast_options: SparkCastOptions,
+}
+
+impl SchemaMapper for SchemaMapping {
+    /// Adapts a `RecordBatch` to match the required (projected) schema using the stored mapping and
+    /// conversions. The produced RecordBatch has a schema that contains only the projected
+    /// columns, so if one needs a RecordBatch with a schema that references columns which are not
+    /// in the projection, it would be better to use `map_partial_batch`.
+    fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result<RecordBatch> {
+        let batch_rows = batch.num_rows();
+        let batch_cols = batch.columns().to_vec();
+
+        let cols = self
+            .required_schema
+            // go through each field in the projected schema
+            .fields()
+            .iter()
+            // and zip it with the index that maps fields from the projected table schema to the
+            // projected file schema in `batch`
+            .zip(&self.field_mappings)
+            // and for each one...
+            .map(|(field, file_idx)| {
+                file_idx.map_or_else(
+                    // If this field only exists in the table, and not in the file, then we know
+                    // that it's null, so just return that.
+                    || Ok(new_null_array(field.data_type(), batch_rows)),
+                    // However, if it does exist in both, then try to cast it to the correct output
+                    // type
+                    |batch_idx| {
+                        spark_cast(
+                            ColumnarValue::Array(Arc::clone(&batch_cols[batch_idx])),
+                            field.data_type(),
+                            &self.cast_options,
+                        )?
+                        .into_array(batch_rows)
+                    },
+                )
+            })
+            .collect::<datafusion_common::Result<Vec<_>, _>>()?;
+
+        // Necessary to handle empty batches
+        let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
+
+        let schema = Arc::<Schema>::clone(&self.required_schema);
+        let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?;
+        Ok(record_batch)
+    }
+
+    /// Adapts a [`RecordBatch`]'s schema into one that has all the correct output types and only
+    /// contains the fields that exist in both the file schema and table schema.
+    ///
+    /// Unlike `map_batch` this method also preserves the columns that
+    /// may not appear in the final output (`projected_table_schema`) but may
+    /// appear in push down predicates
+    fn map_partial_batch(&self, batch: RecordBatch) -> datafusion_common::Result<RecordBatch> {
+        let batch_cols = batch.columns().to_vec();
+        let schema = batch.schema();
+
+        // for each field in the batch's schema (which is based on a file, not a table)...
+        let (cols, fields) = schema
+            .fields()
+            .iter()
+            .zip(batch_cols.iter())
+            .flat_map(|(field, batch_col)| {
+                self.table_schema
+                    // try to get the same field from the table schema that we have stored in self
+                    .field_with_name(field.name())
+                    // and if we don't have it, that's fine, ignore it. This may occur when we've
+                    // created an external table whose fields are a subset of the fields in this
+                    // file, then tried to read data from the file into this table. If that is the
+                    // case here, it's fine to ignore because we don't care about this field
+                    // anyways
+                    .ok()
+                    // but if we do have it,
+                    .map(|table_field| {
+                        // try to cast it into the correct output type. we don't want to ignore this
+                        // error, though, so it's propagated.
+                        spark_cast(
+                            ColumnarValue::Array(Arc::clone(batch_col)),
+                            table_field.data_type(),
+                            &self.cast_options,
+                        )?
+                        .into_array(batch_col.len())
+                        // and if that works, return the field and column.
+                        .map(|new_col| (new_col, table_field.clone()))
+                    })
+            })
+            .collect::<Result<Vec<_>, _>>()?
+            .into_iter()
+            .unzip::<_, _, Vec<_>, Vec<_>>();
+
+        // Necessary to handle empty batches
+        let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
+
+        let schema = Arc::new(Schema::new_with_metadata(fields, schema.metadata().clone()));
+        let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?;
+        Ok(record_batch)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::test_common::file_util::get_temp_filename;
+    use crate::{EvalMode, SparkCastOptions, SparkSchemaAdapterFactory};
+    use arrow::array::{Int32Array, StringArray};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::record_batch::RecordBatch;
+    use arrow_array::UInt32Array;
+    use arrow_schema::SchemaRef;
+    use datafusion::datasource::listing::PartitionedFile;
+    use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
+    use datafusion::execution::object_store::ObjectStoreUrl;
+    use datafusion::execution::TaskContext;
+    use datafusion::physical_plan::ExecutionPlan;
+    use datafusion_common::DataFusionError;
+    use futures::StreamExt;
+    use parquet::arrow::ArrowWriter;
+    use std::fs::File;
+    use std::sync::Arc;
+
+    #[tokio::test]
+    async fn parquet_roundtrip_int_as_string() -> Result<(), DataFusionError> {
+        let file_schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+        ]));
+
+        let ids = Arc::new(Int32Array::from(vec![1, 2, 3])) as Arc<dyn arrow::array::Array>;
+        let names = Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"]))
+            as Arc<dyn arrow::array::Array>;
+        let batch = RecordBatch::try_new(Arc::clone(&file_schema), vec![ids, names])?;
+
+        let required_schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Utf8, false),
+            Field::new("name", DataType::Utf8, false),
+        ]));
+
+        let _ = roundtrip(&batch, required_schema).await?;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn parquet_roundtrip_unsigned_int() -> Result<(), DataFusionError> {
+        let file_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt32, false)]));
+
+        let ids = Arc::new(UInt32Array::from(vec![1, 2, 3])) as Arc<dyn arrow::array::Array>;
+        let batch = RecordBatch::try_new(Arc::clone(&file_schema), vec![ids])?;
+
+        let required_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+        let _ = roundtrip(&batch, required_schema).await?;
+
+        Ok(())
+    }
+
+    /// Create a Parquet file containing a single batch and then read the batch back using
+    /// the specified required_schema. This will cause the SchemaAdapter code to be used.
+    async fn roundtrip(
+        batch: &RecordBatch,
+        required_schema: SchemaRef,
+    ) -> Result<RecordBatch, DataFusionError> {
+        let filename = get_temp_filename();
+        let filename = filename.as_path().as_os_str().to_str().unwrap().to_string();
+        let file = File::create(&filename)?;
+        let mut writer = ArrowWriter::try_new(file, Arc::clone(&batch.schema()), None)?;
+        writer.write(batch)?;
+        writer.close()?;
+
+        let object_store_url = ObjectStoreUrl::local_filesystem();
+        let file_scan_config = FileScanConfig::new(object_store_url, required_schema)
+            .with_file_groups(vec![vec![PartitionedFile::from_path(
+                filename.to_string(),
+            )?]]);
+
+        let mut spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+        spark_cast_options.allow_cast_unsigned_ints = true;
+
+        let parquet_exec = ParquetExec::builder(file_scan_config)
+            .with_schema_adapter_factory(Arc::new(SparkSchemaAdapterFactory::new(
+                spark_cast_options,
+            )))
+            .build();
+
+        let mut stream = parquet_exec
+            .execute(0, Arc::new(TaskContext::default()))
+            .unwrap();
+        stream.next().await.unwrap()
+    }
+}
diff --git a/native/core/src/parquet/util/test_common/file_util.rs b/native/spark-expr/src/test_common/file_util.rs
similarity index 100%
rename from native/core/src/parquet/util/test_common/file_util.rs
rename to native/spark-expr/src/test_common/file_util.rs
diff --git a/native/core/src/parquet/util/test_common/mod.rs b/native/spark-expr/src/test_common/mod.rs
similarity index 80%
copy from native/core/src/parquet/util/test_common/mod.rs
copy to native/spark-expr/src/test_common/mod.rs
index e46d7322..efd25a4a 100644
--- a/native/core/src/parquet/util/test_common/mod.rs
+++ b/native/spark-expr/src/test_common/mod.rs
@@ -14,11 +14,4 @@
 // KIND, either express or implied.  See the License for the
 // specific language governing permissions and limitations
 // under the License.
-
 pub mod file_util;
-pub mod page_util;
-pub mod rand_gen;
-
-pub use self::rand_gen::{random_bools, random_bytes, random_numbers, random_numbers_range};
-
-pub use self::file_util::{get_temp_file, get_temp_filename};
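
As a usage note, the new allow_cast_unsigned_ints option also applies to direct
calls to spark_cast, which is what SchemaMapping uses internally when adapting
batches. A rough sketch of casting an unsigned Parquet column to a signed Spark
type (the crate-root re-exports of spark_cast, EvalMode and SparkCastOptions are
assumed from the imports used in schema_adapter.rs above):

    use std::sync::Arc;

    use arrow_array::{ArrayRef, UInt32Array};
    use arrow_schema::DataType;
    use datafusion_comet_spark_expr::{spark_cast, EvalMode, SparkCastOptions};
    use datafusion_expr::ColumnarValue;

    fn main() -> datafusion_common::Result<()> {
        // A uint32 column, as might be read from a Parquet file with unsigned types
        let ids: ArrayRef = Arc::new(UInt32Array::from(vec![1, 2, 3]));

        // Enable the unsigned-to-signed path added in this commit
        let mut options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
        options.allow_cast_unsigned_ints = true;

        let casted = spark_cast(ColumnarValue::Array(ids), &DataType::Int32, &options)?
            .into_array(3)?;
        assert_eq!(casted.data_type(), &DataType::Int32);
        Ok(())
    }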


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

