This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new b9ac78be feat: Add Spark-compatible implementation of SchemaAdapterFactory (#1169)
b9ac78be is described below
commit b9ac78beffe8b71388ceb8d7e842fd6a07395829
Author: Andy Grove <[email protected]>
AuthorDate: Tue Dec 17 16:35:54 2024 -0700
feat: Add Spark-compatible implementation of SchemaAdapterFactory (#1169)
* Add Spark-compatible SchemaAdapterFactory implementation
* remove prototype code
* fix
* refactor
* implement more cast logic
* implement more cast logic
* add basic test
* improve test
* cleanup
* fmt
* add support for casting unsigned int to signed int
* clippy
* address feedback
* fix test
---
native/Cargo.lock | 67 +++-
native/Cargo.toml | 2 +-
native/core/src/parquet/util/test_common/mod.rs | 3 +-
native/spark-expr/Cargo.toml | 4 +-
native/spark-expr/src/cast.rs | 353 +++++++++++++++----
native/spark-expr/src/lib.rs | 5 +
native/spark-expr/src/schema_adapter.rs | 376 +++++++++++++++++++++
.../src}/test_common/file_util.rs | 0
.../util => spark-expr/src}/test_common/mod.rs | 7 -
9 files changed, 741 insertions(+), 76 deletions(-)
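
For reviewers: the new adapter plugs into DataFusion's schema adapter hook on the Parquet scan. A minimal sketch of how it is wired up, based on the test included in this diff (the `build_scan` helper and its `FileScanConfig` argument are illustrative only, not part of the commit):

use std::sync::Arc;

use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
use datafusion_comet_spark_expr::{EvalMode, SparkCastOptions, SparkSchemaAdapterFactory};

/// Sketch: build a ParquetExec whose file-to-table schema differences are
/// resolved with Spark-compatible casts instead of plain Arrow casts.
fn build_scan(file_scan_config: FileScanConfig) -> ParquetExec {
    // Legacy eval mode, UTC timezone, incompatible casts not allowed.
    let mut spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
    // Let the adapter read Parquet unsigned integer columns into Spark's signed integer types.
    spark_cast_options.allow_cast_unsigned_ints = true;

    ParquetExec::builder(file_scan_config)
        .with_schema_adapter_factory(Arc::new(SparkSchemaAdapterFactory::new(
            spark_cast_options,
        )))
        .build()
}
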
diff --git a/native/Cargo.lock b/native/Cargo.lock
index 7966bb80..538c40ee 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -436,7 +436,18 @@ checksum = "d640d25bc63c50fb1f0b545ffd80207d2e10a4c965530809b40ba3386825c391"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
- "brotli-decompressor",
+ "brotli-decompressor 2.5.1",
+]
+
+[[package]]
+name = "brotli"
+version = "7.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cc97b8f16f944bba54f0433f07e30be199b6dc2bd25937444bbad560bcea29bd"
+dependencies = [
+ "alloc-no-stdlib",
+ "alloc-stdlib",
+ "brotli-decompressor 4.0.1",
]
[[package]]
@@ -449,6 +460,16 @@ dependencies = [
"alloc-stdlib",
]
+[[package]]
+name = "brotli-decompressor"
+version = "4.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362"
+dependencies = [
+ "alloc-no-stdlib",
+ "alloc-stdlib",
+]
+
[[package]]
name = "bumpalo"
version = "3.16.0"
@@ -842,6 +863,7 @@ dependencies = [
"num_cpus",
"object_store",
"parking_lot",
+ "parquet",
"paste",
"pin-project-lite",
"rand",
@@ -878,7 +900,7 @@ dependencies = [
"arrow-schema",
"assertables",
"async-trait",
- "brotli",
+ "brotli 3.5.0",
"bytes",
"crc32fast",
"criterion",
@@ -914,7 +936,7 @@ dependencies = [
"tempfile",
"thiserror",
"tokio",
- "zstd",
+ "zstd 0.11.2+zstd.1.5.2",
]
[[package]]
@@ -943,6 +965,7 @@ dependencies = [
"datafusion-physical-expr",
"futures",
"num",
+ "parquet",
"rand",
"regex",
"thiserror",
@@ -969,6 +992,7 @@ dependencies = [
"libc",
"num_cpus",
"object_store",
+ "parquet",
"paste",
"sqlparser",
"tokio",
@@ -2350,16 +2374,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dea02606ba6f5e856561d8d507dba8bac060aefca2a6c0f1aa1d361fed91ff3e"
dependencies = [
"ahash",
+ "arrow-array",
+ "arrow-buffer",
+ "arrow-cast",
+ "arrow-data",
+ "arrow-ipc",
+ "arrow-schema",
+ "arrow-select",
+ "base64",
+ "brotli 7.0.0",
"bytes",
"chrono",
+ "flate2",
+ "futures",
"half",
"hashbrown 0.14.5",
+ "lz4_flex",
"num",
"num-bigint",
+ "object_store",
"paste",
"seq-macro",
+ "snap",
"thrift",
+ "tokio",
"twox-hash 1.6.3",
+ "zstd 0.13.2",
+ "zstd-sys",
]
[[package]]
@@ -3652,7 +3693,16 @@ version = "0.11.2+zstd.1.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
dependencies = [
- "zstd-safe",
+ "zstd-safe 5.0.2+zstd.1.5.2",
+]
+
+[[package]]
+name = "zstd"
+version = "0.13.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9"
+dependencies = [
+ "zstd-safe 7.2.1",
]
[[package]]
@@ -3665,6 +3715,15 @@ dependencies = [
"zstd-sys",
]
+[[package]]
+name = "zstd-safe"
+version = "7.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059"
+dependencies = [
+ "zstd-sys",
+]
+
[[package]]
name = "zstd-sys"
version = "2.0.13+zstd.1.5.6"
diff --git a/native/Cargo.toml b/native/Cargo.toml
index 4ac85479..bd46cf0c 100644
--- a/native/Cargo.toml
+++ b/native/Cargo.toml
@@ -39,8 +39,8 @@ arrow-buffer = { version = "53.2.0" }
arrow-data = { version = "53.2.0" }
arrow-schema = { version = "53.2.0" }
parquet = { version = "53.2.0", default-features = false, features = ["experimental"] }
-datafusion-common = { version = "43.0.0" }
datafusion = { version = "43.0.0", default-features = false, features = ["unicode_expressions", "crypto_expressions"] }
+datafusion-common = { version = "43.0.0" }
datafusion-functions = { version = "43.0.0", features = ["crypto_expressions"] }
datafusion-functions-nested = { version = "43.0.0", default-features = false }
datafusion-expr = { version = "43.0.0", default-features = false }
diff --git a/native/core/src/parquet/util/test_common/mod.rs b/native/core/src/parquet/util/test_common/mod.rs
index e46d7322..d9254460 100644
--- a/native/core/src/parquet/util/test_common/mod.rs
+++ b/native/core/src/parquet/util/test_common/mod.rs
@@ -15,10 +15,9 @@
// specific language governing permissions and limitations
// under the License.
-pub mod file_util;
pub mod page_util;
pub mod rand_gen;
pub use self::rand_gen::{random_bools, random_bytes, random_numbers, random_numbers_range};
-pub use self::file_util::{get_temp_file, get_temp_filename};
+pub use datafusion_comet_spark_expr::test_common::file_util::{get_temp_file, get_temp_filename};
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index d0bc2fd9..27367d83 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -33,7 +33,7 @@ arrow-buffer = { workspace = true }
arrow-data = { workspace = true }
arrow-schema = { workspace = true }
chrono = { workspace = true }
-datafusion = { workspace = true }
+datafusion = { workspace = true, features = ["parquet"] }
datafusion-common = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-physical-expr = { workspace = true }
@@ -43,9 +43,11 @@ regex = { workspace = true }
thiserror = { workspace = true }
futures = { workspace = true }
twox-hash = "2.0.0"
+rand = { workspace = true }
[dev-dependencies]
arrow-data = {workspace = true}
+parquet = { workspace = true, features = ["arrow"] }
criterion = "0.5.1"
rand = { workspace = true}
tokio = { version = "1", features = ["rt-multi-thread"] }
diff --git a/native/spark-expr/src/cast.rs b/native/spark-expr/src/cast.rs
index f62d0220..d96bcbbd 100644
--- a/native/spark-expr/src/cast.rs
+++ b/native/spark-expr/src/cast.rs
@@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.
+use crate::timezone;
+use crate::utils::array_with_timezone;
+use crate::{EvalMode, SparkError, SparkResult};
use arrow::{
array::{
cast::AsArray,
@@ -35,11 +38,18 @@ use arrow::{
use arrow_array::builder::StringBuilder;
use arrow_array::{DictionaryArray, StringArray, StructArray};
use arrow_schema::{DataType, Field, Schema};
+use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike};
+use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_common::{
    cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue,
};
use datafusion_expr::ColumnarValue;
use datafusion_physical_expr::PhysicalExpr;
+use num::{
+    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
+ ToPrimitive,
+};
+use regex::Regex;
use std::str::FromStr;
use std::{
any::Any,
@@ -49,19 +59,6 @@ use std::{
sync::Arc,
};
-use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike};
-use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
-use num::{
-    cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num,
- ToPrimitive,
-};
-use regex::Regex;
-
-use crate::timezone;
-use crate::utils::array_with_timezone;
-
-use crate::{EvalMode, SparkError, SparkResult};
-
static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");
const MICROS_PER_SECOND: i64 = 1000000;
@@ -141,6 +138,240 @@ pub struct Cast {
pub cast_options: SparkCastOptions,
}
+/// Determine if Comet supports a cast, taking options such as EvalMode and Timezone into account.
+pub fn cast_supported(
+ from_type: &DataType,
+ to_type: &DataType,
+ options: &SparkCastOptions,
+) -> bool {
+ use DataType::*;
+
+ let from_type = if let Dictionary(_, dt) = from_type {
+ dt
+ } else {
+ from_type
+ };
+
+ let to_type = if let Dictionary(_, dt) = to_type {
+ dt
+ } else {
+ to_type
+ };
+
+ if from_type == to_type {
+ return true;
+ }
+
+ match (from_type, to_type) {
+ (Boolean, _) => can_cast_from_boolean(to_type, options),
+ (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
+ if options.allow_cast_unsigned_ints =>
+ {
+ true
+ }
+ (Int8, _) => can_cast_from_byte(to_type, options),
+ (Int16, _) => can_cast_from_short(to_type, options),
+ (Int32, _) => can_cast_from_int(to_type, options),
+ (Int64, _) => can_cast_from_long(to_type, options),
+ (Float32, _) => can_cast_from_float(to_type, options),
+ (Float64, _) => can_cast_from_double(to_type, options),
+ (Decimal128(p, s), _) => can_cast_from_decimal(p, s, to_type, options),
+        (Timestamp(_, None), _) => can_cast_from_timestamp_ntz(to_type, options),
+        (Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options),
+ (Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options),
+ (_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options),
+ (Struct(from_fields), Struct(to_fields)) => from_fields
+ .iter()
+ .zip(to_fields.iter())
+            .all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)),
+ _ => false,
+ }
+}
+
+fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool {
+ use DataType::*;
+ match to_type {
+ Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true,
+ Float32 | Float64 => {
+ // https://github.com/apache/datafusion-comet/issues/326
+            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
+ // Does not support ANSI mode.
+ options.allow_incompat
+ }
+ Decimal128(_, _) => {
+ // https://github.com/apache/datafusion-comet/issues/325
+            // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
+            // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits
+
+ options.allow_incompat
+ }
+ Date32 | Date64 => {
+ // https://github.com/apache/datafusion-comet/issues/327
+ // Only supports years between 262143 BC and 262142 AD
+ options.allow_incompat
+ }
+ Timestamp(_, _) if options.eval_mode == EvalMode::Ansi => {
+ // ANSI mode not supported
+ false
+ }
+ Timestamp(_, Some(tz)) if tz.as_ref() != "UTC" => {
+ // Cast will use UTC instead of $timeZoneId
+ options.allow_incompat
+ }
+ Timestamp(_, _) => {
+ // https://github.com/apache/datafusion-comet/issues/328
+ // Not all valid formats are supported
+ options.allow_incompat
+ }
+ _ => false,
+ }
+}
+
+fn can_cast_to_string(from_type: &DataType, options: &SparkCastOptions) -> bool {
+ use DataType::*;
+ match from_type {
+        Boolean | Int8 | Int16 | Int32 | Int64 | Date32 | Date64 | Timestamp(_, _) => true,
+ Float32 | Float64 => {
+ // There can be differences in precision.
+            // For example, the input "1.4E-45" will produce 1.0E-45
+            // instead of 1.4E-45
+ true
+ }
+ Decimal128(_, _) => {
+ // https://github.com/apache/datafusion-comet/issues/1068
+            // There can be formatting differences in some cases due to Spark using
+            // scientific notation where Comet does not
+ true
+ }
+ Binary => {
+ // https://github.com/apache/datafusion-comet/issues/377
+ // Only works for binary data representing valid UTF-8 strings
+ options.allow_incompat
+ }
+ Struct(fields) => fields
+ .iter()
+ .all(|f| can_cast_to_string(f.data_type(), options)),
+ _ => false,
+ }
+}
+
+fn can_cast_from_timestamp_ntz(to_type: &DataType, options: &SparkCastOptions) -> bool {
+ use DataType::*;
+ match to_type {
+ Timestamp(_, _) | Date32 | Date64 | Utf8 => {
+ // incompatible
+ options.allow_incompat
+ }
+ _ => {
+ // unsupported
+ false
+ }
+ }
+}
+
+fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> bool {
+ use DataType::*;
+ match to_type {
+ Boolean | Int8 | Int16 => {
+ // https://github.com/apache/datafusion-comet/issues/352
+            // this seems like an edge case that isn't important for us to support
+ false
+ }
+ Int64 => {
+ // https://github.com/apache/datafusion-comet/issues/352
+ true
+ }
+ Date32 | Date64 | Utf8 | Decimal128(_, _) => true,
+ _ => {
+ // unsupported
+ false
+ }
+ }
+}
+
+fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool {
+ use DataType::*;
+ matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64)
+}
+
+fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool {
+ use DataType::*;
+ matches!(
+ to_type,
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+ )
+}
+
+fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool {
+ use DataType::*;
+ matches!(
+ to_type,
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _)
+ )
+}
+
+fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool {
+ use DataType::*;
+ match to_type {
+        Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true,
+ Decimal128(_, _) => {
+ // incompatible: no overflow check
+ options.allow_incompat
+ }
+ _ => false,
+ }
+}
+
+fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool {
+ use DataType::*;
+ match to_type {
+ Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
+ Decimal128(_, _) => {
+ // incompatible: no overflow check
+ options.allow_incompat
+ }
+ _ => false,
+ }
+}
+
+fn can_cast_from_float(to_type: &DataType, _: &SparkCastOptions) -> bool {
+ use DataType::*;
+ matches!(
+ to_type,
+ Boolean | Int8 | Int16 | Int32 | Int64 | Float64 | Decimal128(_, _)
+ )
+}
+
+fn can_cast_from_double(to_type: &DataType, _: &SparkCastOptions) -> bool {
+ use DataType::*;
+ matches!(
+ to_type,
+ Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Decimal128(_, _)
+ )
+}
+
+fn can_cast_from_decimal(
+ p1: &u8,
+ _s1: &i8,
+ to_type: &DataType,
+ options: &SparkCastOptions,
+) -> bool {
+ use DataType::*;
+ match to_type {
+ Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true,
+ Decimal128(p2, _) => {
+ if p2 < p1 {
+ // https://github.com/apache/datafusion/issues/13492
+                // Incompatible(Some("Casting to smaller precision is not supported"))
+ options.allow_incompat
+ } else {
+ true
+ }
+ }
+ _ => false,
+ }
+}
+
macro_rules! cast_utf8_to_int {
($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
let len = $array.len();
@@ -560,6 +791,8 @@ pub struct SparkCastOptions {
pub timezone: String,
/// Allow casts that are supported but not guaranteed to be 100% compatible
pub allow_incompat: bool,
+    /// Support casting unsigned ints to signed ints (used by Parquet SchemaAdapter)
+ pub allow_cast_unsigned_ints: bool,
}
impl SparkCastOptions {
@@ -568,6 +801,7 @@ impl SparkCastOptions {
eval_mode,
timezone: timezone.to_string(),
allow_incompat,
+ allow_cast_unsigned_ints: false,
}
}
@@ -576,6 +810,7 @@ impl SparkCastOptions {
eval_mode,
timezone: "".to_string(),
allow_incompat,
+ allow_cast_unsigned_ints: false,
}
}
}
@@ -611,14 +846,14 @@ fn cast_array(
to_type: &DataType,
cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
+ use DataType::*;
    let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?;
let from_type = array.data_type().clone();
let array = match &from_type {
- DataType::Dictionary(key_type, value_type)
- if key_type.as_ref() == &DataType::Int32
- && (value_type.as_ref() == &DataType::Utf8
- || value_type.as_ref() == &DataType::LargeUtf8) =>
+ Dictionary(key_type, value_type)
+ if key_type.as_ref() == &Int32
+                && (value_type.as_ref() == &Utf8 || value_type.as_ref() == &LargeUtf8) =>
{
let dict_array = array
.as_any()
@@ -631,7 +866,7 @@ fn cast_array(
);
let casted_result = match to_type {
-            DataType::Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
+ Dictionary(_, _) => Arc::new(casted_dictionary.clone()),
                _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?,
};
            return Ok(spark_cast_postprocess(casted_result, &from_type, to_type));
@@ -642,70 +877,66 @@ fn cast_array(
let eval_mode = cast_options.eval_mode;
let cast_result = match (from_type, to_type) {
-        (DataType::Utf8, DataType::Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
-        (DataType::LargeUtf8, DataType::Boolean) => {
-            spark_cast_utf8_to_boolean::<i64>(&array, eval_mode)
-        }
-        (DataType::Utf8, DataType::Timestamp(_, _)) => {
+        (Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
+        (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
+        (Utf8, Timestamp(_, _)) => {
            cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
        }
-        (DataType::Utf8, DataType::Date32) => cast_string_to_date(&array, to_type, eval_mode),
- (DataType::Int64, DataType::Int32)
- | (DataType::Int64, DataType::Int16)
- | (DataType::Int64, DataType::Int8)
- | (DataType::Int32, DataType::Int16)
- | (DataType::Int32, DataType::Int8)
- | (DataType::Int16, DataType::Int8)
+ (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+ (Int64, Int32)
+ | (Int64, Int16)
+ | (Int64, Int8)
+ | (Int32, Int16)
+ | (Int32, Int8)
+ | (Int16, Int8)
if eval_mode != EvalMode::Try =>
{
spark_cast_int_to_int(&array, eval_mode, from_type, to_type)
}
-        (DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64) => {
+ (Utf8, Int8 | Int16 | Int32 | Int64) => {
cast_string_to_int::<i32>(to_type, &array, eval_mode)
}
- (
- DataType::LargeUtf8,
-            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
- ) => cast_string_to_int::<i64>(to_type, &array, eval_mode),
-        (DataType::Float64, DataType::Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
-        (DataType::Float64, DataType::LargeUtf8) => {
-            spark_cast_float64_to_utf8::<i64>(&array, eval_mode)
-        }
-        (DataType::Float32, DataType::Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
-        (DataType::Float32, DataType::LargeUtf8) => {
-            spark_cast_float32_to_utf8::<i64>(&array, eval_mode)
-        }
- (DataType::Float32, DataType::Decimal128(precision, scale)) => {
+ (LargeUtf8, Int8 | Int16 | Int32 | Int64) => {
+ cast_string_to_int::<i64>(to_type, &array, eval_mode)
+ }
+        (Float64, Utf8) => spark_cast_float64_to_utf8::<i32>(&array, eval_mode),
+        (Float64, LargeUtf8) => spark_cast_float64_to_utf8::<i64>(&array, eval_mode),
+        (Float32, Utf8) => spark_cast_float32_to_utf8::<i32>(&array, eval_mode),
+        (Float32, LargeUtf8) => spark_cast_float32_to_utf8::<i64>(&array, eval_mode),
+ (Float32, Decimal128(precision, scale)) => {
cast_float32_to_decimal128(&array, *precision, *scale, eval_mode)
}
- (DataType::Float64, DataType::Decimal128(precision, scale)) => {
+ (Float64, Decimal128(precision, scale)) => {
cast_float64_to_decimal128(&array, *precision, *scale, eval_mode)
}
- (DataType::Float32, DataType::Int8)
- | (DataType::Float32, DataType::Int16)
- | (DataType::Float32, DataType::Int32)
- | (DataType::Float32, DataType::Int64)
- | (DataType::Float64, DataType::Int8)
- | (DataType::Float64, DataType::Int16)
- | (DataType::Float64, DataType::Int32)
- | (DataType::Float64, DataType::Int64)
- | (DataType::Decimal128(_, _), DataType::Int8)
- | (DataType::Decimal128(_, _), DataType::Int16)
- | (DataType::Decimal128(_, _), DataType::Int32)
- | (DataType::Decimal128(_, _), DataType::Int64)
+ (Float32, Int8)
+ | (Float32, Int16)
+ | (Float32, Int32)
+ | (Float32, Int64)
+ | (Float64, Int8)
+ | (Float64, Int16)
+ | (Float64, Int32)
+ | (Float64, Int64)
+ | (Decimal128(_, _), Int8)
+ | (Decimal128(_, _), Int16)
+ | (Decimal128(_, _), Int32)
+ | (Decimal128(_, _), Int64)
if eval_mode != EvalMode::Try =>
{
            spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type)
}
- (DataType::Struct(_), DataType::Utf8) => {
- Ok(casts_struct_to_string(array.as_struct(), cast_options)?)
- }
- (DataType::Struct(_), DataType::Struct(_)) => Ok(cast_struct_to_struct(
+        (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?),
+ (Struct(_), Struct(_)) => Ok(cast_struct_to_struct(
array.as_struct(),
from_type,
to_type,
cast_options,
)?),
+ (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
+ if cast_options.allow_cast_unsigned_ints =>
+ {
+ Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
+ }
        _ if is_datafusion_spark_compatible(from_type, to_type, cast_options.allow_incompat) => {
            // use DataFusion cast only when we know that it is compatible with Spark
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 8a574805..f3587310 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -41,6 +41,9 @@ mod kernels;
mod list;
mod regexp;
pub mod scalar_funcs;
+mod schema_adapter;
+pub use schema_adapter::SparkSchemaAdapterFactory;
+
pub mod spark_hash;
mod stddev;
pub use stddev::Stddev;
@@ -51,6 +54,8 @@ mod negative;
pub use negative::{create_negate_expr, NegativeExpr};
mod normalize_nan;
mod temporal;
+
+pub mod test_common;
pub mod timezone;
mod to_json;
mod unbound;
diff --git a/native/spark-expr/src/schema_adapter.rs b/native/spark-expr/src/schema_adapter.rs
new file mode 100644
index 00000000..161ad6f1
--- /dev/null
+++ b/native/spark-expr/src/schema_adapter.rs
@@ -0,0 +1,376 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Custom schema adapter that uses Spark-compatible casts
+
+use crate::cast::cast_supported;
+use crate::{spark_cast, SparkCastOptions};
+use arrow_array::{new_null_array, Array, RecordBatch, RecordBatchOptions};
+use arrow_schema::{Schema, SchemaRef};
+use datafusion::datasource::schema_adapter::{SchemaAdapter, SchemaAdapterFactory, SchemaMapper};
+use datafusion_common::plan_err;
+use datafusion_expr::ColumnarValue;
+use std::sync::Arc;
+
+/// An implementation of DataFusion's `SchemaAdapterFactory` that uses a Spark-compatible
+/// `cast` implementation.
+#[derive(Clone, Debug)]
+pub struct SparkSchemaAdapterFactory {
+ /// Spark cast options
+ cast_options: SparkCastOptions,
+}
+
+impl SparkSchemaAdapterFactory {
+ pub fn new(options: SparkCastOptions) -> Self {
+ Self {
+ cast_options: options,
+ }
+ }
+}
+
+impl SchemaAdapterFactory for SparkSchemaAdapterFactory {
+    /// Create a new `SchemaAdapter` that maps batches from a file schema to the
+    /// required (projected) schema and the table schema, using Spark-compatible casts.
+ fn create(
+ &self,
+ required_schema: SchemaRef,
+ table_schema: SchemaRef,
+ ) -> Box<dyn SchemaAdapter> {
+ Box::new(SparkSchemaAdapter {
+ required_schema,
+ table_schema,
+ cast_options: self.cast_options.clone(),
+ })
+ }
+}
+
+/// This SchemaAdapter requires both the table schema and the projected table
+/// schema. See [`SchemaMapping`] for more details
+#[derive(Clone, Debug)]
+pub struct SparkSchemaAdapter {
+    /// The schema for the table, projected to include only the fields being output (projected) by the
+ /// associated ParquetExec
+ required_schema: SchemaRef,
+ /// The entire table schema for the table we're using this to adapt.
+ ///
+ /// This is used to evaluate any filters pushed down into the scan
+ /// which may refer to columns that are not referred to anywhere
+ /// else in the plan.
+ table_schema: SchemaRef,
+ /// Spark cast options
+ cast_options: SparkCastOptions,
+}
+
+impl SchemaAdapter for SparkSchemaAdapter {
+    /// Map a column index in the table schema to a column index in a particular
+ /// file schema
+ ///
+ /// Panics if index is not in range for the table schema
+    fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
+ let field = self.required_schema.field(index);
+ Some(file_schema.fields.find(field.name())?.0)
+ }
+
+ /// Creates a `SchemaMapping` for casting or mapping the columns from the
+ /// file schema to the table schema.
+ ///
+ /// If the provided `file_schema` contains columns of a different type to
+ /// the expected `table_schema`, the method will attempt to cast the array
+ /// data from the file schema to the table schema where possible.
+ ///
+ /// Returns a [`SchemaMapping`] that can be applied to the output batch
+ /// along with an ordered list of columns to project from the file
+ fn map_schema(
+ &self,
+ file_schema: &Schema,
+ ) -> datafusion_common::Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
+ let mut projection = Vec::with_capacity(file_schema.fields().len());
+        let mut field_mappings = vec![None; self.required_schema.fields().len()];
+
+ for (file_idx, file_field) in file_schema.fields.iter().enumerate() {
+ if let Some((table_idx, table_field)) =
+ self.required_schema.fields().find(file_field.name())
+ {
+ if cast_supported(
+ file_field.data_type(),
+ table_field.data_type(),
+ &self.cast_options,
+ ) {
+ field_mappings[table_idx] = Some(projection.len());
+ projection.push(file_idx);
+ } else {
+ return plan_err!(
+ "Cannot cast file schema field {} of type {:?} to
required schema field of type {:?}",
+ file_field.name(),
+ file_field.data_type(),
+ table_field.data_type()
+ );
+ }
+ }
+ }
+
+ Ok((
+ Arc::new(SchemaMapping {
+ required_schema: Arc::<Schema>::clone(&self.required_schema),
+ field_mappings,
+ table_schema: Arc::<Schema>::clone(&self.table_schema),
+ cast_options: self.cast_options.clone(),
+ }),
+ projection,
+ ))
+ }
+}
+
+// TODO SchemaMapping is mostly copied from DataFusion but calls spark_cast
+// instead of arrow cast - can we reduce the amount of code copied here and make
+// the DataFusion version more extensible?
+
+/// The SchemaMapping struct holds a mapping from the file schema to the table
+/// schema and any necessary type conversions.
+///
+/// Note, because `map_batch` and `map_partial_batch` functions have different
+/// needs, this struct holds two schemas:
+///
+/// 1. The projected **table** schema
+/// 2. The full table schema
+///
+/// [`map_batch`] is used by the ParquetOpener to produce a RecordBatch which
+/// has the projected schema, since that's the schema which is supposed to come
+/// out of the execution of this query. Thus `map_batch` uses
+/// `projected_table_schema` as it can only operate on the projected fields.
+///
+/// [`map_partial_batch`] is used to create a RecordBatch with a schema that
+/// can be used for Parquet predicate pushdown, meaning that it may contain
+/// fields which are not in the projected schema (as the fields that parquet
+/// pushdown filters operate on can be completely distinct from the fields that are
+/// projected (output) out of the ParquetExec). `map_partial_batch` thus uses
+/// `table_schema` to create the resulting RecordBatch (as it could be operating
+/// on any fields in the schema).
+///
+/// [`map_batch`]: Self::map_batch
+/// [`map_partial_batch`]: Self::map_partial_batch
+#[derive(Debug)]
+pub struct SchemaMapping {
+ /// The schema of the table. This is the expected schema after conversion
+ /// and it should match the schema of the query result.
+ required_schema: SchemaRef,
+ /// Mapping from field index in `projected_table_schema` to index in
+ /// projected file_schema.
+ ///
+ /// They are Options instead of just plain `usize`s because the table could
+ /// have fields that don't exist in the file.
+ field_mappings: Vec<Option<usize>>,
+    /// The entire table schema, as opposed to the projected_table_schema (which
+ /// only contains the columns that we are projecting out of this query).
+    /// This contains all fields in the table, regardless of whether they will be
+ /// projected out or not.
+ table_schema: SchemaRef,
+ /// Spark cast options
+ cast_options: SparkCastOptions,
+}
+
+impl SchemaMapper for SchemaMapping {
+    /// Adapts a `RecordBatch` to match the `projected_table_schema` using the stored mapping and
+    /// conversions. The produced RecordBatch has a schema that contains only the projected
+    /// columns, so if one needs a RecordBatch with a schema that references columns which are not
+    /// in the projected schema, it would be better to use `map_partial_batch`
+    fn map_batch(&self, batch: RecordBatch) -> datafusion_common::Result<RecordBatch> {
+ let batch_rows = batch.num_rows();
+ let batch_cols = batch.columns().to_vec();
+
+ let cols = self
+ .required_schema
+ // go through each field in the projected schema
+ .fields()
+ .iter()
+            // and zip it with the index that maps fields from the projected table schema to the
+ // projected file schema in `batch`
+ .zip(&self.field_mappings)
+ // and for each one...
+ .map(|(field, file_idx)| {
+ file_idx.map_or_else(
+                    // If this field only exists in the table, and not in the file, then we know
+ // that it's null, so just return that.
+ || Ok(new_null_array(field.data_type(), batch_rows)),
+                    // However, if it does exist in both, then try to cast it to the correct output
+ // type
+ |batch_idx| {
+ spark_cast(
+                            ColumnarValue::Array(Arc::clone(&batch_cols[batch_idx])),
+ field.data_type(),
+ &self.cast_options,
+ )?
+ .into_array(batch_rows)
+ },
+ )
+ })
+ .collect::<datafusion_common::Result<Vec<_>, _>>()?;
+
+ // Necessary to handle empty batches
+        let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
+
+ let schema = Arc::<Schema>::clone(&self.required_schema);
+        let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?;
+ Ok(record_batch)
+ }
+
+    /// Adapts a [`RecordBatch`]'s schema into one that has all the correct output types and only
+    /// contains the fields that exist in both the file schema and table schema.
+    ///
+    /// Unlike `map_batch` this method also preserves the columns that
+    /// may not appear in the final output (`projected_table_schema`) but may
+    /// appear in push down predicates
+    fn map_partial_batch(&self, batch: RecordBatch) -> datafusion_common::Result<RecordBatch> {
+ let batch_cols = batch.columns().to_vec();
+ let schema = batch.schema();
+
+        // for each field in the batch's schema (which is based on a file, not a table)...
+ let (cols, fields) = schema
+ .fields()
+ .iter()
+ .zip(batch_cols.iter())
+ .flat_map(|(field, batch_col)| {
+ self.table_schema
+                    // try to get the same field from the table schema that we have stored in self
+                    .field_with_name(field.name())
+                    // and if we don't have it, that's fine, ignore it. This may occur when we've
+                    // created an external table whose fields are a subset of the fields in this
+                    // file, then tried to read data from the file into this table. If that is the
+                    // case here, it's fine to ignore because we don't care about this field
+                    // anyways
+ .ok()
+ // but if we do have it,
+ .map(|table_field| {
+                        // try to cast it into the correct output type. we don't want to ignore this
+ // error, though, so it's propagated.
+ spark_cast(
+ ColumnarValue::Array(Arc::clone(batch_col)),
+ table_field.data_type(),
+ &self.cast_options,
+ )?
+ .into_array(batch_col.len())
+ // and if that works, return the field and column.
+ .map(|new_col| (new_col, table_field.clone()))
+ })
+ })
+ .collect::<Result<Vec<_>, _>>()?
+ .into_iter()
+ .unzip::<_, _, Vec<_>, Vec<_>>();
+
+ // Necessary to handle empty batches
+        let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
+
+        let schema = Arc::new(Schema::new_with_metadata(fields, schema.metadata().clone()));
+        let record_batch = RecordBatch::try_new_with_options(schema, cols, &options)?;
+ Ok(record_batch)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use crate::test_common::file_util::get_temp_filename;
+ use crate::{EvalMode, SparkCastOptions, SparkSchemaAdapterFactory};
+ use arrow::array::{Int32Array, StringArray};
+ use arrow::datatypes::{DataType, Field, Schema};
+ use arrow::record_batch::RecordBatch;
+ use arrow_array::UInt32Array;
+ use arrow_schema::SchemaRef;
+ use datafusion::datasource::listing::PartitionedFile;
+ use datafusion::datasource::physical_plan::{FileScanConfig, ParquetExec};
+ use datafusion::execution::object_store::ObjectStoreUrl;
+ use datafusion::execution::TaskContext;
+ use datafusion::physical_plan::ExecutionPlan;
+ use datafusion_common::DataFusionError;
+ use futures::StreamExt;
+ use parquet::arrow::ArrowWriter;
+ use std::fs::File;
+ use std::sync::Arc;
+
+ #[tokio::test]
+ async fn parquet_roundtrip_int_as_string() -> Result<(), DataFusionError> {
+ let file_schema = Arc::new(Schema::new(vec![
+ Field::new("id", DataType::Int32, false),
+ Field::new("name", DataType::Utf8, false),
+ ]));
+
+        let ids = Arc::new(Int32Array::from(vec![1, 2, 3])) as Arc<dyn arrow::array::Array>;
+        let names = Arc::new(StringArray::from(vec!["Alice", "Bob", "Charlie"]))
+            as Arc<dyn arrow::array::Array>;
+        let batch = RecordBatch::try_new(Arc::clone(&file_schema), vec![ids, names])?;
+
+ let required_schema = Arc::new(Schema::new(vec![
+ Field::new("id", DataType::Utf8, false),
+ Field::new("name", DataType::Utf8, false),
+ ]));
+
+ let _ = roundtrip(&batch, required_schema).await?;
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn parquet_roundtrip_unsigned_int() -> Result<(), DataFusionError> {
+        let file_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt32, false)]));
+
+        let ids = Arc::new(UInt32Array::from(vec![1, 2, 3])) as Arc<dyn arrow::array::Array>;
+        let batch = RecordBatch::try_new(Arc::clone(&file_schema), vec![ids])?;
+
+        let required_schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+ let _ = roundtrip(&batch, required_schema).await?;
+
+ Ok(())
+ }
+
+    /// Create a Parquet file containing a single batch and then read the batch back using
+    /// the specified required_schema. This will cause the SchemaAdapter code to be used.
+ async fn roundtrip(
+ batch: &RecordBatch,
+ required_schema: SchemaRef,
+ ) -> Result<RecordBatch, DataFusionError> {
+ let filename = get_temp_filename();
+        let filename = filename.as_path().as_os_str().to_str().unwrap().to_string();
+        let file = File::create(&filename)?;
+        let mut writer = ArrowWriter::try_new(file, Arc::clone(&batch.schema()), None)?;
+ writer.write(batch)?;
+ writer.close()?;
+
+ let object_store_url = ObjectStoreUrl::local_filesystem();
+        let file_scan_config = FileScanConfig::new(object_store_url, required_schema)
+            .with_file_groups(vec![vec![PartitionedFile::from_path(
+                filename.to_string(),
+            )?]]);
+
+        let mut spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+        spark_cast_options.allow_cast_unsigned_ints = true;
+
+        let parquet_exec = ParquetExec::builder(file_scan_config)
+            .with_schema_adapter_factory(Arc::new(SparkSchemaAdapterFactory::new(
+                spark_cast_options,
+            )))
+            .build();
+
+ let mut stream = parquet_exec
+ .execute(0, Arc::new(TaskContext::default()))
+ .unwrap();
+ stream.next().await.unwrap()
+ }
+}
diff --git a/native/core/src/parquet/util/test_common/file_util.rs b/native/spark-expr/src/test_common/file_util.rs
similarity index 100%
rename from native/core/src/parquet/util/test_common/file_util.rs
rename to native/spark-expr/src/test_common/file_util.rs
diff --git a/native/core/src/parquet/util/test_common/mod.rs b/native/spark-expr/src/test_common/mod.rs
similarity index 80%
copy from native/core/src/parquet/util/test_common/mod.rs
copy to native/spark-expr/src/test_common/mod.rs
index e46d7322..efd25a4a 100644
--- a/native/core/src/parquet/util/test_common/mod.rs
+++ b/native/spark-expr/src/test_common/mod.rs
@@ -14,11 +14,4 @@
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
-
pub mod file_util;
-pub mod page_util;
-pub mod rand_gen;
-
-pub use self::rand_gen::{random_bools, random_bytes, random_numbers, random_numbers_range};
-
-pub use self::file_util::{get_temp_file, get_temp_filename};
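
A closing note on the new `cast_supported` function in cast.rs: it reports whether Comet can perform a given cast under the supplied options. A small sketch of the behavior added here (assuming `cast_supported` is reachable from the crate root; the exact import path may differ):

use arrow_schema::DataType;
use datafusion_comet_spark_expr::{cast_supported, EvalMode, SparkCastOptions};

fn main() {
    let mut options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);

    // Unsigned-to-signed integer casts are only reported as supported once the
    // SchemaAdapter-specific flag is enabled.
    assert!(!cast_supported(&DataType::UInt32, &DataType::Int32, &options));
    options.allow_cast_unsigned_ints = true;
    assert!(cast_supported(&DataType::UInt32, &DataType::Int32, &options));

    // String-to-float casts remain gated behind allow_incompat
    // (https://github.com/apache/datafusion-comet/issues/326).
    assert!(!cast_supported(&DataType::Utf8, &DataType::Float64, &options));
}
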
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]