slfan1989 commented on code in PR #1514: URL: https://github.com/apache/auron/pull/1514#discussion_r2464313269
########## native-engine/datafusion-ext-functions/src/spark_pow.rs: ########## @@ -0,0 +1,384 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use arrow::{ + array::{Array, ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array}, + datatypes::DataType, +}; +use datafusion::{ + common::{Result, ScalarValue}, + error::DataFusionError, + physical_plan::ColumnarValue, +}; + +/// Spark-like `pow` for Short/Int/Long/Float/Double: +/// - Scalar×Scalar: +/// * Convert both to f64; if exponent is integer (i16/i32/i64) use `powi`, +/// otherwise `powf`. +/// * Special case: `0 ** negative` => `+∞` (matches array path and common +/// DB behavior). +/// * Returns `Float64`. +/// - Array×Array: +/// * Convert both arrays to `Vec<Option<f64>>`, compute element-wise, nulls +/// propagate. +/// * Special case: `0 ** negative` => `+∞`. +/// * Returns `Float64Array`. 
+pub fn spark_pow(args: &[ColumnarValue]) -> Result<ColumnarValue> { + if args.len() != 2 { + return Err(DataFusionError::Plan( + "Expected 2 arguments for pow function".to_string(), + )); + } + + match (&args[0], &args[1]) { + (ColumnarValue::Scalar(b), ColumnarValue::Scalar(e)) => scalar_pow(b, e), + (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => { + let out = pow_arrays(lhs.as_ref(), rhs.as_ref())?; + Ok(ColumnarValue::Array(out)) + } + _ => Err(DataFusionError::Plan( + "pow expects both arguments to be both scalars or both arrays".into(), + )), + } +} + +// ----------------------------- Scalar × Scalar ----------------------------- + +fn scalar_pow(base: &ScalarValue, exp: &ScalarValue) -> Result<ColumnarValue> { + let b = scalar_to_f64(base) + .ok_or_else(|| DataFusionError::Plan("Unsupported base type for pow".to_string()))?; + + // integer exponent ⇒ prefer powi (faster/more stable). powi accepts negative + // i32. + if let Some(e_i32) = scalar_integer_exponent_i32(exp)? { + // Special case: 0 ** negative => +∞ + if b == 0.0 && e_i32 < 0 { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( + f64::INFINITY, + )))); + } + let result = b.powi(e_i32); + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result)))); + } + + // Floating exponent path + let e = scalar_to_f64(exp) + .ok_or_else(|| DataFusionError::Plan("Unsupported exponent type for pow".to_string()))?; + + // Special case: 0 ** negative => +∞ + if b == 0.0 && e < 0.0 { + return Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( + f64::INFINITY, + )))); + } + + let result = b.powf(e); + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(result)))) +} + +/// Convert supported ScalarValue to f64 (lossy for integers by design). 
+fn scalar_to_f64(v: &ScalarValue) -> Option<f64> { + match v { + ScalarValue::Int16(Some(x)) => Some(*x as f64), + ScalarValue::Int32(Some(x)) => Some(*x as f64), + ScalarValue::Int64(Some(x)) => Some(*x as f64), + ScalarValue::Float32(Some(x)) => Some(*x as f64), + ScalarValue::Float64(Some(x)) => Some(*x), + _ => None, + } +} + +/// If `v` is an integer scalar exponent, return it as i32 (error if i64 out of +/// i32 range). +fn scalar_integer_exponent_i32(v: &ScalarValue) -> Result<Option<i32>> { + let out = match v { + ScalarValue::Int16(Some(x)) => Some(*x as i32), + ScalarValue::Int32(Some(x)) => Some(*x), + ScalarValue::Int64(Some(x)) => { + let e = *x; + if e < i32::MIN as i64 || e > i32::MAX as i64 { + return Err(DataFusionError::Plan(format!( + "Exponent {} outside i32 range for powi", + e + ))); + } + Some(e as i32) + } + _ => None, + }; + Ok(out) +} + +// ----------------------------- Array × Array ------------------------------ + +/// Convert both arrays to f64 vectors and compute element-wise powf. +/// - Returns Float64Array +/// - Nulls propagate +/// - 0 ** negative => +∞ +/// - Length must match +fn pow_arrays(lhs: &dyn Array, rhs: &dyn Array) -> Result<ArrayRef> { + let lvals = to_f64_vec(lhs)?; Review Comment: Thank you for helping review the code; I will continue to improve it. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
