Copilot commented on code in PR #3211: URL: https://github.com/apache/datafusion-comet/pull/3211#discussion_r2763832333
########## spark/src/test/scala/org/apache/comet/CometStaticInvokeSuite.scala: ########## @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase + +class CometStaticInvokeSuite extends CometTestBase { + + test("aes_encrypt basic - verify native execution") { + withTable("t1") { + sql("CREATE TABLE t1(data STRING, key STRING) USING parquet") + sql("""INSERT INTO t1 VALUES + ('Spark', '0000111122223333'), + ('SQL', 'abcdefghijklmnop')""") + + val query = """ + SELECT + data, + hex(aes_encrypt(cast(data as binary), cast(key as binary))) as encrypted + FROM t1 + """ + + checkSparkAnswerAndOperator(query) + + val df = sql(query) + val plan = df.queryExecution.executedPlan.toString + assert( + plan.contains("CometProject") || plan.contains("CometNative"), + s"Expected native execution but got Spark fallback:\n$plan") + } + } + + test("aes_encrypt with mode") { + withTable("t1") { + sql("CREATE TABLE t1(data STRING, key STRING) USING parquet") + sql("INSERT INTO t1 VALUES ('test', '1234567890123456')") + + val query = """ + SELECT hex(aes_encrypt(cast(data as binary), cast(key as binary), 'GCM')) + FROM t1 + """ + + checkSparkAnswerAndOperator(query) + } + } + + test("aes_encrypt with all parameters") { + withTable("t1") { + sql("CREATE TABLE t1(data STRING, key STRING) USING parquet") + sql("INSERT INTO t1 VALUES ('test', '1234567890123456')") + + val query = """ + SELECT hex(aes_encrypt( + cast(data as binary), + cast(key as binary), + 'GCM', + 'DEFAULT', + cast('initialization' as binary), + cast('additional' as binary) + )) + FROM t1 + """ + + checkSparkAnswerAndOperator(query) + } + } + + test("aes_encrypt wrapped in multiple functions") { + withTable("t1") { + sql("CREATE TABLE t1(data STRING, key STRING) USING parquet") + sql("INSERT INTO t1 VALUES ('test', '1234567890123456')") + + val query = """ + SELECT + upper(hex(aes_encrypt(cast(data as binary), cast(key as binary)))) as encrypted, + length(hex(aes_encrypt(cast(data as binary), cast(key as binary)))) as len + FROM t1 + """ + + checkSparkAnswerAndOperator(query) + } + } +} Review Comment: Test coverage is missing important edge cases for the aes_encrypt function. Consider adding tests for: null input/key handling, invalid key lengths (e.g., 8 bytes, 15 bytes, 17 bytes), invalid IV lengths for different modes, unsupported mode/padding combinations, empty input data, and very large input data. These test cases would help ensure robust error handling and compatibility with Spark's behavior. ########## native/spark-expr/src/encryption_funcs/aes_encrypt.rs: ########## @@ -0,0 +1,468 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Array, BinaryArray, BinaryBuilder, StringArray}; +use datafusion::common::{DataFusionError, Result, ScalarValue}; +use datafusion::logical_expr::ColumnarValue; +use std::sync::Arc; + +use super::cipher_modes::get_cipher_mode; + +pub fn spark_aes_encrypt(args: &[ColumnarValue]) -> Result<ColumnarValue> { + if args.len() < 2 || args.len() > 6 { + return Err(DataFusionError::Execution(format!( + "aes_encrypt expects 2-6 arguments, got {}", + args.len() + ))); + } + + let mode_default = ColumnarValue::Scalar(ScalarValue::Utf8(Some("GCM".to_string()))); + let padding_default = ColumnarValue::Scalar(ScalarValue::Utf8(Some("DEFAULT".to_string()))); + let iv_default = ColumnarValue::Scalar(ScalarValue::Binary(Some(vec![]))); + let aad_default = ColumnarValue::Scalar(ScalarValue::Binary(Some(vec![]))); + + let input_arg = &args[0]; + let key_arg = &args[1]; + let mode_arg = args.get(2).unwrap_or(&mode_default); + let padding_arg = args.get(3).unwrap_or(&padding_default); + let iv_arg = args.get(4).unwrap_or(&iv_default); + let aad_arg = args.get(5).unwrap_or(&aad_default); + + let batch_size = get_batch_size(args)?; + + if batch_size == 1 { + encrypt_scalar(input_arg, key_arg, mode_arg, padding_arg, iv_arg, aad_arg) + } else { + encrypt_batch( + input_arg, + key_arg, + mode_arg, + padding_arg, + iv_arg, + aad_arg, + batch_size, + ) + } +} + +fn encrypt_scalar( + input_arg: &ColumnarValue, + key_arg: &ColumnarValue, + mode_arg: &ColumnarValue, + padding_arg: &ColumnarValue, + iv_arg: &ColumnarValue, + aad_arg: &ColumnarValue, +) -> Result<ColumnarValue> { + let input = match input_arg { + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => opt, + _ => return Err(DataFusionError::Execution("Invalid input type".to_string())), + }; + + let key = match key_arg { + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => opt, + _ => return Err(DataFusionError::Execution("Invalid key type".to_string())), + }; + + if input.is_none() || key.is_none() { + return Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))); + } + + let mode = match mode_arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.as_str(), + _ => "GCM", + }; + + let padding = match padding_arg { + ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.as_str(), + _ => "DEFAULT", + }; + + let iv = match iv_arg { + ColumnarValue::Scalar(ScalarValue::Binary(Some(v))) if !v.is_empty() => Some(v.as_slice()), + _ => None, + }; + + let aad = match aad_arg { + ColumnarValue::Scalar(ScalarValue::Binary(Some(v))) if !v.is_empty() => Some(v.as_slice()), + _ => None, + }; + + let cipher = get_cipher_mode(mode, padding)?; + + let encrypted = cipher + .encrypt(input.as_ref().unwrap(), key.as_ref().unwrap(), iv, aad) + .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; + + Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(encrypted)))) +} + +fn encrypt_batch( + input_arg: &ColumnarValue, + key_arg: &ColumnarValue, + mode_arg: &ColumnarValue, + padding_arg: &ColumnarValue, + iv_arg: &ColumnarValue, + aad_arg: &ColumnarValue, + batch_size: usize, +) -> Result<ColumnarValue> { + let input_array = to_binary_array(input_arg, batch_size)?; + let key_array = to_binary_array(key_arg, batch_size)?; + let mode_array = to_string_array(mode_arg, batch_size)?; + let padding_array = to_string_array(padding_arg, batch_size)?; + let iv_array = to_binary_array(iv_arg, batch_size)?; + let aad_array = to_binary_array(aad_arg, batch_size)?; + + let mut builder = BinaryBuilder::new(); + + for i in 0..batch_size { + if input_array.is_null(i) || key_array.is_null(i) { + builder.append_null(); + continue; + } + + let input = input_array.value(i); + let key = key_array.value(i); + let mode = mode_array.value(i); + let padding = padding_array.value(i); + let iv = if iv_array.is_null(i) || iv_array.value(i).is_empty() { + None + } else { + Some(iv_array.value(i)) + }; + let aad = if aad_array.is_null(i) || aad_array.value(i).is_empty() { + None + } else { + Some(aad_array.value(i)) + }; + + match get_cipher_mode(mode, padding) { + Ok(cipher) => match cipher.encrypt(input, key, iv, aad) { + Ok(encrypted) => builder.append_value(&encrypted), + Err(_) => builder.append_null(), + }, + Err(_) => builder.append_null(), + } Review Comment: Error handling in batch processing silently converts encryption errors to null values. This may mask important errors such as invalid key lengths or unsupported cipher modes. Consider logging errors or propagating them in a way that provides better visibility to users, particularly for cryptographic operations where failure reasons are critical for debugging. ########## native/spark-expr/src/encryption_funcs/crypto_utils.rs: ########## @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::common::DataFusionError; + +#[derive(Debug, PartialEq)] +pub enum CryptoError { + InvalidKeyLength(usize), + InvalidIvLength { expected: usize, actual: usize }, + UnsupportedMode(String, String), + UnsupportedIv(String), + UnsupportedAad(String), + EncryptionFailed(String), +} + +impl From<CryptoError> for DataFusionError { + fn from(err: CryptoError) -> Self { + DataFusionError::Execution(format!("{:?}", err)) + } +} + +pub fn validate_key_length(key: &[u8]) -> Result<(), CryptoError> { + match key.len() { + 16 | 24 | 32 => Ok(()), + len => Err(CryptoError::InvalidKeyLength(len)), + } +} + +pub fn generate_random_iv(length: usize) -> Vec<u8> { + use rand::Rng; + let mut iv = vec![0u8; length]; + rand::rng().fill(&mut iv[..]); Review Comment: The `generate_random_iv` function uses `rand::rng()` which generates cryptographically insecure random numbers. For AES encryption, especially GCM mode, initialization vectors must be cryptographically secure random values. Using non-secure random IVs can compromise the security of the encryption. Consider using `rand::rngs::OsRng` or `rand::thread_rng()` with a cryptographically secure RNG to generate IVs. ```suggestion use rand::rngs::OsRng; use rand::RngCore; let mut iv = vec![0u8; length]; OsRng.fill_bytes(&mut iv[..]); ``` ########## native/spark-expr/src/encryption_funcs/cipher_modes.rs: ########## @@ -0,0 +1,470 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::crypto_utils::{ + generate_random_iv, validate_iv_length, validate_key_length, CryptoError, +}; + +pub trait CipherMode: Send + Sync + std::fmt::Debug { + #[allow(dead_code)] + fn name(&self) -> &str; + #[allow(dead_code)] + fn iv_length(&self) -> usize; + #[allow(dead_code)] + fn supports_aad(&self) -> bool; + + fn encrypt( + &self, + input: &[u8], + key: &[u8], + iv: Option<&[u8]>, + aad: Option<&[u8]>, + ) -> Result<Vec<u8>, CryptoError>; +} + +#[derive(Debug)] +pub struct EcbMode; +#[derive(Debug)] +pub struct CbcMode; +#[derive(Debug)] +pub struct GcmMode; + +impl CipherMode for EcbMode { + fn name(&self) -> &str { + "ECB" + } + + fn iv_length(&self) -> usize { + 0 + } + + fn supports_aad(&self) -> bool { + false + } + + fn encrypt( + &self, + input: &[u8], + key: &[u8], + iv: Option<&[u8]>, + aad: Option<&[u8]>, + ) -> Result<Vec<u8>, CryptoError> { + use aes::{Aes128, Aes192, Aes256}; + use cipher::{block_padding::Pkcs7, BlockEncryptMut, KeyInit}; + use ecb::Encryptor; + + validate_key_length(key)?; + + if iv.is_some() { + return Err(CryptoError::UnsupportedIv("ECB".to_string())); + } + if aad.is_some() { + return Err(CryptoError::UnsupportedAad("ECB".to_string())); + } + + let encrypted = match key.len() { + 16 => { + let cipher = Encryptor::<Aes128>::new(key.into()); + cipher.encrypt_padded_vec_mut::<Pkcs7>(input) + } + 24 => { + let cipher = Encryptor::<Aes192>::new(key.into()); + cipher.encrypt_padded_vec_mut::<Pkcs7>(input) + } + 32 => { + let cipher = Encryptor::<Aes256>::new(key.into()); + cipher.encrypt_padded_vec_mut::<Pkcs7>(input) + } + _ => unreachable!("Key length validated above"), + }; + + Ok(encrypted) + } +} + +impl CipherMode for CbcMode { + fn name(&self) -> &str { + "CBC" + } + + fn iv_length(&self) -> usize { + 16 + } + + fn supports_aad(&self) -> bool { + false + } + + fn encrypt( + &self, + input: &[u8], + key: &[u8], + iv: Option<&[u8]>, + aad: Option<&[u8]>, + ) -> Result<Vec<u8>, CryptoError> { + use aes::{Aes128, Aes192, Aes256}; + use cbc::cipher::{block_padding::Pkcs7, BlockEncryptMut, KeyIvInit}; + use cbc::Encryptor; + + validate_key_length(key)?; + + if aad.is_some() { + return Err(CryptoError::UnsupportedAad("CBC".to_string())); + } + + let iv_bytes = match iv { + Some(iv) => { + validate_iv_length(iv, 16)?; + iv.to_vec() + } + None => generate_random_iv(16), + }; + + let ciphertext = match key.len() { + 16 => { + let cipher = Encryptor::<Aes128>::new(key.into(), iv_bytes.as_slice().into()); + cipher.encrypt_padded_vec_mut::<Pkcs7>(input) + } + 24 => { + let cipher = Encryptor::<Aes192>::new(key.into(), iv_bytes.as_slice().into()); + cipher.encrypt_padded_vec_mut::<Pkcs7>(input) + } + 32 => { + let cipher = Encryptor::<Aes256>::new(key.into(), iv_bytes.as_slice().into()); + cipher.encrypt_padded_vec_mut::<Pkcs7>(input) + } + _ => unreachable!("Key length validated above"), + }; + + let mut result = iv_bytes; + result.extend_from_slice(&ciphertext); + Ok(result) + } +} + +impl CipherMode for GcmMode { + fn name(&self) -> &str { + "GCM" + } + + fn iv_length(&self) -> usize { + 12 + } + + fn supports_aad(&self) -> bool { + true + } + + fn encrypt( + &self, + input: &[u8], + key: &[u8], + iv: Option<&[u8]>, + aad: Option<&[u8]>, + ) -> Result<Vec<u8>, CryptoError> { + use aes_gcm::aead::{Aead, Payload}; + use aes_gcm::{Aes128Gcm, Aes256Gcm, KeyInit, Nonce}; + + validate_key_length(key)?; + + let iv_bytes = match iv { + Some(iv) => { + validate_iv_length(iv, 12)?; + iv.to_vec() + } + None => generate_random_iv(12), + }; + + let nonce = Nonce::from_slice(&iv_bytes); + + let ciphertext = match key.len() { + 16 => { + let cipher = Aes128Gcm::new(key.into()); + let payload = match aad { + Some(aad_data) => Payload { + msg: input, + aad: aad_data, + }, + None => Payload { + msg: input, + aad: &[], + }, + }; + cipher + .encrypt(nonce, payload) + .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))? + } + 24 | 32 => { + let cipher = Aes256Gcm::new(key.into()); + let payload = match aad { + Some(aad_data) => Payload { + msg: input, + aad: aad_data, + }, + None => Payload { + msg: input, + aad: &[], + }, + }; + cipher + .encrypt(nonce, payload) + .map_err(|e| CryptoError::EncryptionFailed(e.to_string()))? + } + _ => unreachable!("Key length validated above"), Review Comment: The GCM mode implementation maps both 24-byte and 32-byte keys to `Aes256Gcm`. However, the standard AES-GCM supports AES-128, AES-192, and AES-256, which require 16, 24, and 32-byte keys respectively. Using `Aes256Gcm` for a 24-byte key is incorrect. The `aes-gcm` crate provides `Aes192Gcm` for 24-byte keys. This should be updated to use the correct cipher implementation for each key size to ensure proper encryption standards compliance. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
