This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push: new 27aecb628 feat: Translate Hadoop S3A configurations to object_store configurations (#1817) 27aecb628 is described below commit 27aecb628c4a16163a03301109c9862e2730f6d0 Author: Kristin Cowalcijk <b...@wherobots.com> AuthorDate: Tue Jun 3 10:50:02 2025 +0800 feat: Translate Hadoop S3A configurations to object_store configurations (#1817) --- native/Cargo.lock | 590 +++++- native/Cargo.toml | 2 + native/core/Cargo.toml | 2 + native/core/src/execution/planner.rs | 15 +- native/core/src/parquet/mod.rs | 2 + native/core/src/parquet/objectstore/mod.rs | 18 + native/core/src/parquet/objectstore/s3.rs | 1952 ++++++++++++++++++++ native/core/src/parquet/parquet_support.rs | 22 +- native/proto/src/proto/operator.proto | 7 + pom.xml | 16 + spark/pom.xml | 8 + .../apache/comet/objectstore/NativeConfig.scala | 79 + .../org/apache/comet/serde/QueryPlanSerde.scala | 20 + .../comet/objectstore/NativeConfigSuite.scala | 73 + .../comet/parquet/ParquetReadFromS3Suite.scala | 126 ++ 15 files changed, 2901 insertions(+), 31 deletions(-) diff --git a/native/Cargo.lock b/native/Cargo.lock index d597eed94..375e1125a 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -375,6 +375,353 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-config" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02a18fd934af6ae7ca52410d4548b98eb895aab0f1ea417d168d85db1434a141" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-sdk-sso", + "aws-sdk-ssooidc", + "aws-sdk-sts", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "hex", + "http 1.3.1", + "ring", + "time", + "tokio", + "tracing", + "url", + "zeroize", +] + +[[package]] +name = "aws-credential-types" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "687bc16bc431a8533fe0097c7f0182874767f920989d7260950172ae8e3c4465" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "zeroize", +] + +[[package]] +name = "aws-lc-rs" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fcc8f365936c834db5514fc45aee5b1202d677e6b40e48468aaaa8183ca8c7" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61b1d86e7705efe1be1b569bab41d4fa1e14e220b60a160f78de2db687add079" +dependencies = [ + "bindgen 0.69.5", + "cc", + "cmake", + "dunce", + "fs_extra", +] + +[[package]] +name = "aws-runtime" +version = "1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c4063282c69991e57faab9e5cb21ae557e59f5b0fb285c196335243df8dc25c" +dependencies = [ + "aws-credential-types", + "aws-sigv4", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "http-body 0.4.6", + "percent-encoding", + "pin-project-lite", + "tracing", + "uuid", +] + +[[package]] +name = "aws-sdk-sso" +version = "1.71.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"95a4fd09d6e863655d99cd2260f271c6d1030dc6bfad68e19e126d2e4c8ceb18" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-ssooidc" +version = "1.72.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3224ab02ebb3074467a33d57caf6fcb487ca36f3697fdd381b0428dc72380696" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sdk-sts" +version = "1.72.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6933f189ed1255e78175fbd73fb200c0aae7240d220ed3346f567b0ddca3083" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-query", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-smithy-xml", + "aws-types", + "fastrand", + "http 0.2.12", + "regex-lite", + "tracing", +] + +[[package]] +name = "aws-sigv4" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3734aecf9ff79aa401a6ca099d076535ab465ff76b46440cf567c8e70b65dc13" +dependencies = [ + "aws-credential-types", + "aws-smithy-http", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "form_urlencoded", + "hex", + "hmac", + "http 0.2.12", + "http 1.3.1", + "percent-encoding", + "sha2", + "time", + "tracing", +] + +[[package]] +name = "aws-smithy-async" +version = "1.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e190749ea56f8c42bf15dd76c65e14f8f765233e6df9b0506d9d934ebef867c" +dependencies = [ + "futures-util", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "aws-smithy-http" +version = "0.62.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99335bec6cdc50a346fda1437f9fefe33abf8c99060739a546a16457f2862ca9" +dependencies = [ + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http 1.3.1", + "http-body 0.4.6", + "percent-encoding", + "pin-project-lite", + "pin-utils", + "tracing", +] + +[[package]] +name = "aws-smithy-http-client" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e44697a9bded898dcd0b1cb997430d949b87f4f8940d91023ae9062bf218250" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "h2", + "http 1.3.1", + "hyper", + "hyper-rustls", + "hyper-util", + "pin-project-lite", + "rustls", + "rustls-native-certs", + "rustls-pki-types", + "tokio", + "tower", + "tracing", +] + +[[package]] +name = "aws-smithy-json" +version = "0.61.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92144e45819cae7dc62af23eac5a038a58aa544432d2102609654376a900bd07" +dependencies = [ + "aws-smithy-types", +] + +[[package]] +name = "aws-smithy-observability" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9364d5989ac4dd918e5cc4c4bdcc61c9be17dcd2586ea7f69e348fc7c6cab393" +dependencies = [ + "aws-smithy-runtime-api", +] 
+ +[[package]] +name = "aws-smithy-query" +version = "0.60.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2fbd61ceb3fe8a1cb7352e42689cec5335833cd9f94103a61e98f9bb61c64bb" +dependencies = [ + "aws-smithy-types", + "urlencoding", +] + +[[package]] +name = "aws-smithy-runtime" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14302f06d1d5b7d333fd819943075b13d27c7700b414f574c3c35859bfb55d5e" +dependencies = [ + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-http-client", + "aws-smithy-observability", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "fastrand", + "http 0.2.12", + "http 1.3.1", + "http-body 0.4.6", + "http-body 1.0.1", + "pin-project-lite", + "pin-utils", + "tokio", + "tracing", +] + +[[package]] +name = "aws-smithy-runtime-api" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1e5d9e3a80a18afa109391fb5ad09c3daf887b516c6fd805a157c6ea7994a57" +dependencies = [ + "aws-smithy-async", + "aws-smithy-types", + "bytes", + "http 0.2.12", + "http 1.3.1", + "pin-project-lite", + "tokio", + "tracing", + "zeroize", +] + +[[package]] +name = "aws-smithy-types" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40076bd09fadbc12d5e026ae080d0930defa606856186e31d83ccc6a255eeaf3" +dependencies = [ + "base64-simd", + "bytes", + "bytes-utils", + "http 0.2.12", + "http 1.3.1", + "http-body 0.4.6", + "http-body 1.0.1", + "http-body-util", + "itoa", + "num-integer", + "pin-project-lite", + "pin-utils", + "ryu", + "serde", + "time", +] + +[[package]] +name = "aws-smithy-xml" +version = "0.60.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab0b0166827aa700d3dc519f72f8b3a91c35d0b8d042dc5d643a91e6f80648fc" +dependencies = [ + "xmlparser", +] + +[[package]] +name = "aws-types" +version = "1.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a322fec39e4df22777ed3ad8ea868ac2f94cd15e1a55f6ee8d8d6305057689a" +dependencies = [ + "aws-credential-types", + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "rustc_version", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -396,6 +743,16 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64-simd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195" +dependencies = [ + "outref", + "vsimd", +] + [[package]] name = "bigdecimal" version = "0.4.8" @@ -431,6 +788,29 @@ dependencies = [ "which", ] +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.9.1", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.101", + "which", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -519,6 +899,16 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +[[package]] +name = "bytes-utils" +version = "0.1.4" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35" +dependencies = [ + "bytes", + "either", +] + [[package]] name = "cast" version = "0.3.0" @@ -660,6 +1050,15 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "combine" version = "4.6.7" @@ -966,6 +1365,8 @@ dependencies = [ "arrow", "assertables", "async-trait", + "aws-config", + "aws-credential-types", "bytes", "crc32fast", "criterion", @@ -1529,6 +1930,15 @@ dependencies = [ "uuid", ] +[[package]] +name = "deranged" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" +dependencies = [ + "powerfmt", +] + [[package]] name = "derivative" version = "2.2.0" @@ -1568,6 +1978,12 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "either" version = "1.15.0" @@ -1682,7 +2098,7 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25f164ff6334da016dffd1c29a3c05b81c35b857ef829d3fa9e58ae8d3e6f76b" dependencies = [ - "bindgen", + "bindgen 0.64.0", "cc", "lazy_static", "libc", @@ -1696,7 +2112,7 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0f38e500596a428817fd4fd8a9a21da32f4edb3250e87886039670b12ea02f5d" dependencies = [ - "bindgen", + "bindgen 0.64.0", "cc", "lazy_static", "libc", @@ -1704,6 +2120,12 @@ dependencies = [ "url", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "futures" version = "0.3.31" @@ -1853,7 +2275,7 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http", + "http 1.3.1", "indexmap", "slab", "tokio", @@ -1911,6 +2333,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "home" version = "0.5.11" @@ -1920,6 +2351,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "http" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http" version = "1.3.1" @@ -1931,6 +2373,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http-body" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" +dependencies = [ + "bytes", + "http 0.2.12", + 
"pin-project-lite", +] + [[package]] name = "http-body" version = "1.0.1" @@ -1938,7 +2391,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http", + "http 1.3.1", ] [[package]] @@ -1949,8 +2402,8 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "pin-project-lite", ] @@ -1976,8 +2429,8 @@ dependencies = [ "futures-channel", "futures-util", "h2", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "httparse", "itoa", "pin-project-lite", @@ -1992,7 +2445,7 @@ version = "0.27.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03a01595e11bdcec50946522c32dde3fc6914743000a68b93000965f2f02406d" dependencies = [ - "http", + "http 1.3.1", "hyper", "hyper-util", "rustls", @@ -2014,8 +2467,8 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "hyper", "ipnet", "libc", @@ -2228,6 +2681,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" @@ -2443,9 +2905,9 @@ checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765" dependencies = [ "autocfg", "scopeguard", @@ -2635,6 +3097,12 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-format" version = "0.4.4" @@ -2707,7 +3175,7 @@ dependencies = [ "chrono", "form_urlencoded", "futures", - "http", + "http 1.3.1", "http-body-util", "httparse", "humantime", @@ -2760,11 +3228,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "outref" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" + [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", "parking_lot_core", @@ -2772,9 +3246,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "bc838d2a56b5b1a6c25f55575dfc605fabb63bb2365f6c2353ef9159aa69e4a5" dependencies = [ "cfg-if", "libc", @@ -2961,6 +3435,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "pprof" version = "0.14.0" @@ -3317,6 +3797,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-lite" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" + [[package]] name = "regex-syntax" version = "0.8.5" @@ -3334,8 +3820,8 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "http-body-util", "hyper", "hyper-rustls", @@ -3450,6 +3936,7 @@ version = "0.23.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" dependencies = [ + "aws-lc-rs", "once_cell", "ring", "rustls-pki-types", @@ -3495,6 +3982,7 @@ version = "0.103.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -3655,6 +4143,15 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook-registry" +version = "1.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" +dependencies = [ + "libc", +] + [[package]] name = "simd-adler32" version = "0.3.7" @@ -3933,6 +4430,36 @@ dependencies = [ "tikv-jemalloc-sys", ] +[[package]] +name = "time" +version = "0.3.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" +dependencies = [ + "deranged", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" + +[[package]] +name = "time-macros" +version = "0.2.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tiny-keccak" version = "2.0.2" @@ -3989,6 +4516,7 @@ dependencies = [ "mio", "parking_lot", "pin-project-lite", + "signal-hook-registry", "socket2", "tokio-macros", "windows-sys 0.52.0", @@ -4052,8 +4580,8 @@ dependencies = [ "bitflags 2.9.1", "bytes", "futures-util", - "http", - "http-body", + "http 1.3.1", + "http-body 1.0.1", "iri-string", "pin-project-lite", "tower", @@ -4194,6 +4722,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -4217,6 +4751,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "vsimd" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" + [[package]] name = "walkdir" version = "2.5.0" @@ 
-4684,6 +5224,12 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" +[[package]] +name = "xmlparser" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4" + [[package]] name = "yoke" version = "0.8.0" diff --git a/native/Cargo.toml b/native/Cargo.toml index 7142ffa90..970572d81 100644 --- a/native/Cargo.toml +++ b/native/Cargo.toml @@ -51,6 +51,8 @@ regex = "1.9.6" thiserror = "2" object_store = { version = "0.12.0", features = ["gcp", "azure", "aws", "http"] } url = "2.2" +aws-config = "1.6.3" +aws-credential-types = "1.2.3" [profile.release] debug = true diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml index 5e636dbe5..cc39b174c 100644 --- a/native/core/Cargo.toml +++ b/native/core/Cargo.toml @@ -69,6 +69,8 @@ datafusion-comet-spark-expr = { workspace = true } datafusion-comet-proto = { workspace = true } object_store = { workspace = true } url = { workspace = true } +aws-config = { workspace = true } +aws-credential-types = { workspace = true } parking_lot = "0.12.3" datafusion-comet-objectstore-hdfs = { path = "../hdfs", optional = true, default-features = false, features = ["hdfs"] } diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index bbfdc4f35..578c24e34 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -69,7 +69,7 @@ use datafusion_comet_spark_expr::{create_comet_physical_fun, create_negate_expr, use crate::execution::operators::ExecutionError::GeneralError; use crate::execution::shuffle::CompressionCodec; use crate::execution::spark_plan::SparkPlan; -use crate::parquet::parquet_support::prepare_object_store; +use crate::parquet::parquet_support::prepare_object_store_with_configs; use datafusion::common::scalar::ScalarStructBuilder; use datafusion::common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter}, @@ -1156,8 +1156,17 @@ impl PhysicalPlanner { .and_then(|f| f.partitioned_file.first()) .map(|f| f.file_path.clone()) .ok_or(GeneralError("Failed to locate file".to_string()))?; - let (object_store_url, _) = - prepare_object_store(self.session_ctx.runtime_env(), one_file)?; + + let object_store_options: HashMap<String, String> = scan + .object_store_options + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + let (object_store_url, _) = prepare_object_store_with_configs( + self.session_ctx.runtime_env(), + one_file, + &object_store_options, + )?; // Generate file groups let mut file_groups: Vec<Vec<PartitionedFile>> = diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index b24591e9d..d5a8fa2b8 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -26,6 +26,8 @@ pub mod parquet_support; pub mod read; pub mod schema_adapter; +mod objectstore; + use std::task::Poll; use std::{boxed::Box, ptr::NonNull, sync::Arc}; diff --git a/native/core/src/parquet/objectstore/mod.rs b/native/core/src/parquet/objectstore/mod.rs new file mode 100644 index 000000000..bedae08f6 --- /dev/null +++ b/native/core/src/parquet/objectstore/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod s3; diff --git a/native/core/src/parquet/objectstore/s3.rs b/native/core/src/parquet/objectstore/s3.rs new file mode 100644 index 000000000..c1301272f --- /dev/null +++ b/native/core/src/parquet/objectstore/s3.rs @@ -0,0 +1,1952 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use log::{debug, error}; +use std::collections::HashMap; +use url::Url; + +use std::{ + sync::{Arc, RwLock}, + time::{Duration, SystemTime}, +}; + +use crate::execution::jni_api::get_runtime; +use async_trait::async_trait; +use aws_config::{ + ecs::EcsCredentialsProvider, environment::EnvironmentVariableCredentialsProvider, + imds::credentials::ImdsCredentialsProvider, meta::credentials::CredentialsProviderChain, + provider_config::ProviderConfig, sts::AssumeRoleProvider, + web_identity_token::WebIdentityTokenCredentialsProvider, BehaviorVersion, +}; +use aws_credential_types::{ + provider::{error::CredentialsError, ProvideCredentials}, + Credentials, +}; +use object_store::{ + aws::{resolve_bucket_region, AmazonS3Builder, AmazonS3ConfigKey, AwsCredential}, + path::Path, + ClientOptions, CredentialProvider, ObjectStore, ObjectStoreScheme, +}; + +/// Creates an S3 object store using options specified as Hadoop S3A configurations. +/// +/// # Arguments +/// +/// * `url` - The URL of the S3 object to access. +/// * `configs` - The Hadoop S3A configurations to use for building the object store. +/// * `min_ttl` - Time buffer before credential expiry when refresh should be triggered. +/// +/// # Returns +/// +/// * `(Box<dyn ObjectStore>, Path)` - The object store and path of the S3 object store. 
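+///
+/// # Example
+///
+/// Illustrative sketch only; the bucket name, object key, and region below are hypothetical,
+/// and the anonymous credential provider mirrors the unit test in this module:
+///
+/// ```ignore
+/// use std::collections::HashMap;
+/// use std::time::Duration;
+/// use url::Url;
+///
+/// let url = Url::parse("s3a://my-bucket/data/part-00000.snappy.parquet").unwrap();
+/// let mut configs = HashMap::new();
+/// configs.insert("fs.s3a.endpoint.region".to_string(), "us-east-1".to_string());
+/// configs.insert(
+///     "fs.s3a.aws.credentials.provider".to_string(),
+///     "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider".to_string(),
+/// );
+/// let (store, path) = create_store(&url, &configs, Duration::from_secs(300)).unwrap();
+/// ```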
+/// +pub fn create_store( + url: &Url, + configs: &HashMap<String, String>, + min_ttl: Duration, +) -> Result<(Box<dyn ObjectStore>, Path), object_store::Error> { + let (scheme, path) = ObjectStoreScheme::parse(url)?; + if scheme != ObjectStoreScheme::AmazonS3 { + return Err(object_store::Error::Generic { + store: "S3", + source: format!("Scheme of URL is not S3: {}", url).into(), + }); + } + let path = Path::from_url_path(path)?; + + let mut builder = AmazonS3Builder::new() + .with_url(url.to_string()) + .with_allow_http(true); + let bucket = url.host_str().ok_or_else(|| object_store::Error::Generic { + store: "S3", + source: "Missing bucket name in S3 URL".into(), + })?; + + let credential_provider = + get_runtime().block_on(build_credential_provider(configs, bucket, min_ttl))?; + builder = match credential_provider { + Some(provider) => builder.with_credentials(Arc::new(provider)), + None => builder.with_skip_signature(true), + }; + + let s3_configs = extract_s3_config_options(configs, bucket); + debug!("S3 configs for bucket {}: {:?}", bucket, s3_configs); + + // When using the default AWS S3 endpoint (no custom endpoint configured), a valid region + // is required. If no region is explicitly configured, attempt to auto-resolve it by + // making a HeadBucket request to determine the bucket's region. + if !s3_configs.contains_key(&AmazonS3ConfigKey::Endpoint) + && !s3_configs.contains_key(&AmazonS3ConfigKey::Region) + { + let region = + get_runtime().block_on(resolve_bucket_region(bucket, &ClientOptions::new()))?; + debug!("resolved region: {:?}", region); + builder = builder.with_config(AmazonS3ConfigKey::Region, region.to_string()); + } + + for (key, value) in s3_configs { + builder = builder.with_config(key, value); + } + + let object_store = builder.build()?; + + Ok((Box::new(object_store), path)) +} + +/// Extracts S3 configuration options from Hadoop S3A configurations and returns them +/// as a HashMap of (AmazonS3ConfigKey, String) pairs that can be applied to an AmazonS3Builder. +/// +/// # Arguments +/// +/// * `configs` - The Hadoop S3A configurations to extract from. +/// * `bucket` - The bucket name to extract configurations for. +/// +/// # Returns +/// +/// * `HashMap<AmazonS3ConfigKey, String>` - The extracted S3 configuration options. 
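+///
+/// For illustration (hypothetical bucket and value): the per-bucket Hadoop key
+/// `fs.s3a.bucket.my-bucket.endpoint.region = us-west-2` takes precedence over the global
+/// `fs.s3a.endpoint.region` and is returned as `AmazonS3ConfigKey::Region` for `my-bucket`.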
+/// +fn extract_s3_config_options( + configs: &HashMap<String, String>, + bucket: &str, +) -> HashMap<AmazonS3ConfigKey, String> { + let mut s3_configs = HashMap::new(); + + // Extract region configuration + if let Some(region) = get_config_trimmed(configs, bucket, "endpoint.region") { + s3_configs.insert(AmazonS3ConfigKey::Region, region.to_string()); + } + + // Extract and handle path style access (virtual hosted style) + let mut virtual_hosted_style_request = false; + if let Some(path_style) = get_config_trimmed(configs, bucket, "path.style.access") { + virtual_hosted_style_request = path_style.to_lowercase() == "true"; + s3_configs.insert( + AmazonS3ConfigKey::VirtualHostedStyleRequest, + virtual_hosted_style_request.to_string(), + ); + } + + // Extract endpoint configuration and modify if virtual hosted style is enabled + if let Some(endpoint) = get_config_trimmed(configs, bucket, "endpoint") { + let normalized_endpoint = + normalize_endpoint(endpoint, bucket, virtual_hosted_style_request); + if let Some(endpoint) = normalized_endpoint { + s3_configs.insert(AmazonS3ConfigKey::Endpoint, endpoint); + } + } + + // Extract request payer configuration + if let Some(requester_pays) = get_config_trimmed(configs, bucket, "requester.pays.enabled") { + let requester_pays_enabled = requester_pays.to_lowercase() == "true"; + s3_configs.insert( + AmazonS3ConfigKey::RequestPayer, + requester_pays_enabled.to_string(), + ); + } + + s3_configs +} + +fn normalize_endpoint( + endpoint: &str, + bucket: &str, + virtual_hosted_style_request: bool, +) -> Option<String> { + if endpoint.is_empty() { + return None; + } + + // This is the default Hadoop S3A configuration. Explicitly specifying this endpoint will lead to HTTP + // request failures when using object_store crate, so we ignore it and let object_store crate + // use the default endpoint. 
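+    // Illustrative examples (hypothetical endpoint and bucket): "minio.local:9000" is normalized
+    // to "https://minio.local:9000"; when virtual-hosted-style requests are enabled for bucket
+    // "data", the bucket is appended, giving "https://minio.local:9000/data".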
+ if endpoint == "s3.amazonaws.com" { + return None; + } + + let endpoint = if !endpoint.starts_with("http://") && !endpoint.starts_with("https://") { + format!("https://{}", endpoint) + } else { + endpoint.to_string() + }; + + if virtual_hosted_style_request { + if endpoint.ends_with("/") { + Some(format!("{}{}", endpoint, bucket)) + } else { + Some(format!("{}/{}", endpoint, bucket)) + } + } else { + Some(endpoint) // Avoid extra to_string() call since endpoint is already a String + } +} + +fn get_config<'a>( + configs: &'a HashMap<String, String>, + bucket: &str, + property: &str, +) -> Option<&'a String> { + let per_bucket_key = format!("fs.s3a.bucket.{}.{}", bucket, property); + configs.get(&per_bucket_key).or_else(|| { + let global_key = format!("fs.s3a.{}", property); + configs.get(&global_key) + }) +} + +fn get_config_trimmed<'a>( + configs: &'a HashMap<String, String>, + bucket: &str, + property: &str, +) -> Option<&'a str> { + get_config(configs, bucket, property).map(|s| s.trim()) +} + +// Hadoop S3A credential provider constants +const HADOOP_IAM_INSTANCE: &str = "org.apache.hadoop.fs.s3a.auth.IAMInstanceCredentialsProvider"; +const HADOOP_SIMPLE: &str = "org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider"; +const HADOOP_TEMPORARY: &str = "org.apache.hadoop.fs.s3a.TemporaryAWSCredentialsProvider"; +const HADOOP_ASSUMED_ROLE: &str = "org.apache.hadoop.fs.s3a.auth.AssumedRoleCredentialProvider"; +const HADOOP_ANONYMOUS: &str = "org.apache.hadoop.fs.s3a.AnonymousAWSCredentialsProvider"; + +// AWS SDK credential provider constants +const AWS_CONTAINER_CREDENTIALS: &str = + "software.amazon.awssdk.auth.credentials.ContainerCredentialsProvider"; +const AWS_CONTAINER_CREDENTIALS_V1: &str = "com.amazonaws.auth.ContainerCredentialsProvider"; +const AWS_EC2_CONTAINER_CREDENTIALS: &str = + "com.amazonaws.auth.EC2ContainerCredentialsProviderWrapper"; +const AWS_INSTANCE_PROFILE: &str = + "software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider"; +const AWS_INSTANCE_PROFILE_V1: &str = "com.amazonaws.auth.InstanceProfileCredentialsProvider"; +const AWS_ENVIRONMENT: &str = + "software.amazon.awssdk.auth.credentials.EnvironmentVariableCredentialsProvider"; +const AWS_ENVIRONMENT_V1: &str = "com.amazonaws.auth.EnvironmentVariableCredentialsProvider"; +const AWS_WEB_IDENTITY: &str = + "software.amazon.awssdk.auth.credentials.WebIdentityTokenFileCredentialsProvider"; +const AWS_WEB_IDENTITY_V1: &str = "com.amazonaws.auth.WebIdentityTokenCredentialsProvider"; +const AWS_ANONYMOUS: &str = "software.amazon.awssdk.auth.credentials.AnonymousCredentialsProvider"; +const AWS_ANONYMOUS_V1: &str = "com.amazonaws.auth.AnonymousAWSCredentials"; + +/// Builds an AWS credential provider from the given configurations. +/// It first checks if the credential provider is anonymous, and if so, returns `None`. +/// Otherwise, it builds a [CachedAwsCredentialProvider] from the given configurations. +/// +/// # Arguments +/// +/// * `configs` - The Hadoop S3A configurations to use for building the credential provider. +/// * `bucket` - The bucket to build the credential provider for. +/// * `min_ttl` - Time buffer before credential expiry when refresh should be triggered. +/// +/// # Returns +/// +/// * `None` - If the credential provider is anonymous. +/// * `Some(CachedAwsCredentialProvider)` - If the credential provider is not anonymous. 
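+///
+/// For illustration (hypothetical configuration): setting
+/// `fs.s3a.aws.credentials.provider = org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider`
+/// together with `fs.s3a.access.key` and `fs.s3a.secret.key` yields a cached static
+/// credential provider for the bucket.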
+/// +async fn build_credential_provider( + configs: &HashMap<String, String>, + bucket: &str, + min_ttl: Duration, +) -> Result<Option<CachedAwsCredentialProvider>, object_store::Error> { + let aws_credential_provider_names = + get_config_trimmed(configs, bucket, "aws.credentials.provider"); + let aws_credential_provider_names = + aws_credential_provider_names.map_or(Vec::new(), |s| parse_credential_provider_names(s)); + if aws_credential_provider_names + .iter() + .any(|name| is_anonymous_credential_provider(name)) + { + if aws_credential_provider_names.len() > 1 { + return Err(object_store::Error::Generic { + store: "S3", + source: + "Anonymous credential provider cannot be mixed with other credential providers" + .into(), + }); + } + return Ok(None); + } + let provider_metadata = build_chained_aws_credential_provider_metadata( + aws_credential_provider_names, + configs, + bucket, + )?; + debug!( + "Credential providers for S3 bucket {}: {}", + bucket, + provider_metadata.simple_string() + ); + let provider = provider_metadata.create_credential_provider().await?; + Ok(Some(CachedAwsCredentialProvider::new( + provider, + provider_metadata, + min_ttl, + ))) +} + +fn parse_credential_provider_names(aws_credential_provider_names: &str) -> Vec<&str> { + aws_credential_provider_names + .split(',') + .map(|s| s.trim()) + .filter(|s| !s.is_empty()) + .collect::<Vec<&str>>() +} + +fn is_anonymous_credential_provider(credential_provider_name: &str) -> bool { + [HADOOP_ANONYMOUS, AWS_ANONYMOUS_V1, AWS_ANONYMOUS].contains(&credential_provider_name) +} + +fn build_chained_aws_credential_provider_metadata( + credential_provider_names: Vec<&str>, + configs: &HashMap<String, String>, + bucket: &str, +) -> Result<CredentialProviderMetadata, object_store::Error> { + if credential_provider_names.is_empty() { + // Use the default credential provider chain. 
This is actually more permissive than + // the default Hadoop S3A FileSystem behavior, which only uses + // TemporaryAWSCredentialsProvider, SimpleAWSCredentialsProvider, + // EnvironmentVariableCredentialsProvider and IAMInstanceCredentialsProvider + return Ok(CredentialProviderMetadata::Default); + } + + // Safety: credential_provider_names is not empty, taking its first element is safe + let provider_name = credential_provider_names[0]; + let provider_metadata = build_aws_credential_provider_metadata(provider_name, configs, bucket)?; + if credential_provider_names.len() == 1 { + // No need to chain the provider as there's only one provider + return Ok(provider_metadata); + } + + // More than one credential provider names were specified, we need to chain them together + let mut metadata_vec = vec![provider_metadata]; + for provider_name in credential_provider_names[1..].iter() { + let provider_metadata = + build_aws_credential_provider_metadata(provider_name, configs, bucket)?; + metadata_vec.push(provider_metadata); + } + + Ok(CredentialProviderMetadata::Chain(metadata_vec)) +} + +fn build_aws_credential_provider_metadata( + credential_provider_name: &str, + configs: &HashMap<String, String>, + bucket: &str, +) -> Result<CredentialProviderMetadata, object_store::Error> { + match credential_provider_name { + AWS_CONTAINER_CREDENTIALS + | AWS_CONTAINER_CREDENTIALS_V1 + | AWS_EC2_CONTAINER_CREDENTIALS => Ok(CredentialProviderMetadata::Ecs), + AWS_INSTANCE_PROFILE | AWS_INSTANCE_PROFILE_V1 => Ok(CredentialProviderMetadata::Imds), + HADOOP_IAM_INSTANCE => Ok(CredentialProviderMetadata::Chain(vec![ + CredentialProviderMetadata::Ecs, + CredentialProviderMetadata::Imds, + ])), + AWS_ENVIRONMENT_V1 | AWS_ENVIRONMENT => Ok(CredentialProviderMetadata::Environment), + HADOOP_SIMPLE | HADOOP_TEMPORARY => { + build_static_credential_provider_metadata(credential_provider_name, configs, bucket) + } + HADOOP_ASSUMED_ROLE => build_assume_role_credential_provider_metadata(configs, bucket), + AWS_WEB_IDENTITY_V1 | AWS_WEB_IDENTITY => Ok(CredentialProviderMetadata::WebIdentity), + _ => Err(object_store::Error::Generic { + store: "S3", + source: format!( + "Unsupported credential provider: {}", + credential_provider_name + ) + .into(), + }), + } +} + +fn build_static_credential_provider_metadata( + credential_provider_name: &str, + configs: &HashMap<String, String>, + bucket: &str, +) -> Result<CredentialProviderMetadata, object_store::Error> { + let access_key_id = get_config_trimmed(configs, bucket, "access.key"); + let secret_access_key = get_config_trimmed(configs, bucket, "secret.key"); + let session_token = if credential_provider_name == HADOOP_TEMPORARY { + get_config_trimmed(configs, bucket, "session.token") + } else { + None + }; + + // Allow static credential provider creation even when access/secret keys are missing. + // This maintains compatibility with Hadoop S3A FileSystem, whose default credential chain + // includes TemporaryAWSCredentialsProvider. Missing credentials won't prevent other + // providers in the chain from working - this provider will error only when accessed. 
+ let mut is_valid = access_key_id.is_some() && secret_access_key.is_some(); + if credential_provider_name == HADOOP_TEMPORARY { + is_valid = is_valid && session_token.is_some(); + }; + + Ok(CredentialProviderMetadata::Static { + is_valid, + access_key: access_key_id.unwrap_or("").to_string(), + secret_key: secret_access_key.unwrap_or("").to_string(), + session_token: session_token.map(|s| s.to_string()), + }) +} + +fn build_assume_role_credential_provider_metadata( + configs: &HashMap<String, String>, + bucket: &str, +) -> Result<CredentialProviderMetadata, object_store::Error> { + let base_provider_names = + get_config_trimmed(configs, bucket, "assumed.role.credentials.provider") + .map(|s| parse_credential_provider_names(s)); + let base_provider_names = if let Some(v) = base_provider_names { + if v.iter().any(|name| is_anonymous_credential_provider(name)) { + return Err(object_store::Error::Generic { + store: "S3", + source: "Anonymous credential provider cannot be used as assumed role credential provider".into(), + }); + } + v + } else { + // If credential provider for performing assume role operation is not specified, we'll use simple + // credential provider first, and fallback to environment variable credential provider. This is the + // same behavior as Hadoop S3A FileSystem. + vec![HADOOP_SIMPLE, AWS_ENVIRONMENT] + }; + + let role_arn = get_config_trimmed(configs, bucket, "assumed.role.arn").ok_or( + object_store::Error::Generic { + store: "S3", + source: "Missing required assume role ARN configuration".into(), + }, + )?; + let default_session_name = "comet-parquet-s3".to_string(); + let session_name = get_config_trimmed(configs, bucket, "assumed.role.session.name") + .unwrap_or(&default_session_name); + + let base_provider_metadata = + build_chained_aws_credential_provider_metadata(base_provider_names, configs, bucket)?; + Ok(CredentialProviderMetadata::AssumeRole { + role_arn: role_arn.to_string(), + session_name: session_name.to_string(), + base_provider_metadata: Box::new(base_provider_metadata), + }) +} + +/// A caching wrapper around AWS credential providers that implements the object_store `CredentialProvider` trait. +/// +/// This struct bridges AWS SDK credential providers (`ProvideCredentials`) with the object_store +/// crate's `CredentialProvider` trait, enabling seamless use of AWS credentials with object_store's +/// S3 implementation. It also provides credential caching to improve performance and reduce the +/// frequency of credential refresh operations. Many AWS credential providers (like IMDS, ECS, STS +/// assume role) involve network calls or complex authentication flows that can be expensive to +/// repeat constantly. +#[derive(Debug)] +struct CachedAwsCredentialProvider { + /// The underlying AWS credential provider that this cache wraps. + /// This can be any provider implementing `ProvideCredentials` (static, IMDS, ECS, assume role, etc.) + provider: Arc<dyn ProvideCredentials>, + + /// Cache holding the most recently fetched credentials. [CredentialProvider] is required to be + /// Send + Sync, so we have to use Arc + RwLock to make it thread-safe. + cached: Arc<RwLock<Option<aws_credential_types::Credentials>>>, + + /// Time buffer before credential expiry when refresh should be triggered. + /// For example, if set to 5 minutes, credentials will be refreshed when they have + /// 5 minutes or less remaining before expiration. This prevents credential expiry + /// during active operations. 
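+    ///
+    /// For example, with a `min_ttl` of 300 seconds, a cached credential expiring at time `T`
+    /// is treated as stale from `T - 300s` onward and is refreshed on the next request.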
+ min_ttl: Duration, + + /// The metadata of the credential provider. Only present when running tests. This field is used + /// to assert on the structure of the credential provider. + #[cfg(test)] + metadata: CredentialProviderMetadata, +} + +impl CachedAwsCredentialProvider { + #[allow(unused_variables)] + fn new( + credential_provider: Arc<dyn ProvideCredentials>, + metadata: CredentialProviderMetadata, + min_ttl: Duration, + ) -> Self { + Self { + provider: credential_provider, + cached: Arc::new(RwLock::new(None)), + min_ttl, + #[cfg(test)] + metadata, + } + } + + #[cfg(test)] + fn metadata(&self) -> CredentialProviderMetadata { + self.metadata.clone() + } + + fn fetch_credential(&self) -> Option<aws_credential_types::Credentials> { + let locked = self.cached.read().unwrap(); + locked.as_ref().and_then(|cred| match cred.expiry() { + Some(expiry) => { + if expiry < SystemTime::now() + self.min_ttl { + None + } else { + Some(cred.clone()) + } + } + None => Some(cred.clone()), + }) + } + + async fn refresh_credential(&self) -> object_store::Result<aws_credential_types::Credentials> { + let credentials = self.provider.provide_credentials().await.map_err(|e| { + error!("Failed to retrieve credentials: {:?}", e); + object_store::Error::Generic { + store: "S3", + source: Box::new(e), + } + })?; + *self.cached.write().unwrap() = Some(credentials.clone()); + Ok(credentials) + } +} + +#[async_trait] +impl CredentialProvider for CachedAwsCredentialProvider { + /// The type of credential returned by this provider + type Credential = AwsCredential; + + /// Return a credential + async fn get_credential(&self) -> object_store::Result<Arc<AwsCredential>> { + let credentials = match self.fetch_credential() { + Some(cred) => cred, + None => self.refresh_credential().await?, + }; + Ok(Arc::new(AwsCredential { + key_id: credentials.access_key_id().to_string(), + secret_key: credentials.secret_access_key().to_string(), + token: credentials.session_token().map(|s| s.to_string()), + })) + } +} + +/// A custom AWS credential provider that holds static, pre-configured credentials. +/// +/// This provider is used when the S3 credential configuration specifies static access keys, +/// such as when using Hadoop's `SimpleAWSCredentialsProvider` or `TemporaryAWSCredentialsProvider`. +/// Unlike dynamic credential providers (like IMDS or ECS), this provider returns the same +/// credentials every time without any external API calls. +#[derive(Debug)] +struct StaticCredentialProvider { + is_valid: bool, + cred: Credentials, +} + +impl StaticCredentialProvider { + fn new(is_valid: bool, ak: String, sk: String, token: Option<String>) -> Self { + let mut builder = Credentials::builder() + .access_key_id(ak) + .secret_access_key(sk) + .provider_name("AwsStaticCredentialProvider"); + if let Some(token) = token { + builder = builder.session_token(token); + } + let cred = builder.build(); + Self { is_valid, cred } + } +} + +impl ProvideCredentials for StaticCredentialProvider { + fn provide_credentials<'a>( + &'a self, + ) -> aws_credential_types::provider::future::ProvideCredentials<'a> + where + Self: 'a, + { + if self.is_valid { + aws_credential_types::provider::future::ProvideCredentials::ready(Ok(self.cred.clone())) + } else { + aws_credential_types::provider::future::ProvideCredentials::ready(Err( + CredentialsError::not_loaded_no_source(), + )) + } + } +} + +/// Structural representation of credential provider types. 
It reflects the nested structure of the +/// credential providers, and can be used as blueprint to creating the actual credential providers. +/// We are defining this type because it is hard to assert on the structures of credential providers +/// using the `dyn ProvideCredentials` values directly. Please refer to the test cases for usages of +/// this type. +#[derive(Debug, Clone, PartialEq)] +enum CredentialProviderMetadata { + Default, + Ecs, + Imds, + Environment, + WebIdentity, + Static { + is_valid: bool, + access_key: String, + secret_key: String, + session_token: Option<String>, + }, + AssumeRole { + role_arn: String, + session_name: String, + base_provider_metadata: Box<CredentialProviderMetadata>, + }, + Chain(Vec<CredentialProviderMetadata>), +} + +impl CredentialProviderMetadata { + fn name(&self) -> &'static str { + match self { + CredentialProviderMetadata::Default => "Default", + CredentialProviderMetadata::Ecs => "Ecs", + CredentialProviderMetadata::Imds => "Imds", + CredentialProviderMetadata::Environment => "Environment", + CredentialProviderMetadata::WebIdentity => "WebIdentity", + CredentialProviderMetadata::Static { .. } => "Static", + CredentialProviderMetadata::AssumeRole { .. } => "AssumeRole", + CredentialProviderMetadata::Chain(..) => "Chain", + } + } + + /// Return a simple name for the credential provider. Security sensitive informations are not included. + /// This is useful for logging and debugging. + fn simple_string(&self) -> String { + match self { + CredentialProviderMetadata::Default => "Default".to_string(), + CredentialProviderMetadata::Ecs => "Ecs".to_string(), + CredentialProviderMetadata::Imds => "Imds".to_string(), + CredentialProviderMetadata::Environment => "Environment".to_string(), + CredentialProviderMetadata::WebIdentity => "WebIdentity".to_string(), + CredentialProviderMetadata::Static { is_valid, .. } => { + format!("Static(valid: {})", is_valid) + } + CredentialProviderMetadata::AssumeRole { + role_arn, + session_name, + base_provider_metadata, + } => { + format!( + "AssumeRole(role: {}, session: {}, base: {})", + role_arn, + session_name, + base_provider_metadata.simple_string() + ) + } + CredentialProviderMetadata::Chain(providers) => { + let provider_strings: Vec<String> = + providers.iter().map(|p| p.simple_string()).collect(); + format!("Chain({})", provider_strings.join(" -> ")) + } + } + } +} + +impl CredentialProviderMetadata { + /// Create a credential provider from the metadata. + /// + /// Note: this function is not covered by tests. However, the implementation of this function is + /// quite straightforward and should be easy to verify. 
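+    ///
+    /// For example, `Chain([Environment, Imds])` builds a `CredentialsProviderChain` that
+    /// first tries environment variables and then falls back to the instance metadata service.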
+ async fn create_credential_provider( + &self, + ) -> Result<Arc<dyn ProvideCredentials>, object_store::Error> { + match self { + CredentialProviderMetadata::Default => { + let config = aws_config::defaults(BehaviorVersion::latest()).load().await; + let credential_provider = + config + .credentials_provider() + .ok_or(object_store::Error::Generic { + store: "S3", + source: "Cannot get default credential provider chain".into(), + })?; + Ok(Arc::new(credential_provider)) + } + CredentialProviderMetadata::Ecs => { + let credential_provider = EcsCredentialsProvider::builder().build(); + Ok(Arc::new(credential_provider)) + } + CredentialProviderMetadata::Imds => { + let credential_provider = ImdsCredentialsProvider::builder().build(); + Ok(Arc::new(credential_provider)) + } + CredentialProviderMetadata::Environment => { + let credential_provider = EnvironmentVariableCredentialsProvider::new(); + Ok(Arc::new(credential_provider)) + } + CredentialProviderMetadata::WebIdentity => { + let credential_provider = WebIdentityTokenCredentialsProvider::builder() + .configure(&ProviderConfig::with_default_region().await) + .build(); + Ok(Arc::new(credential_provider)) + } + CredentialProviderMetadata::Static { + is_valid, + access_key, + secret_key, + session_token, + } => { + let credential_provider = StaticCredentialProvider::new( + *is_valid, + access_key.clone(), + secret_key.clone(), + session_token.clone(), + ); + Ok(Arc::new(credential_provider)) + } + CredentialProviderMetadata::AssumeRole { + role_arn, + session_name, + base_provider_metadata, + } => { + let base_provider = + Box::pin(base_provider_metadata.create_credential_provider()).await?; + let credential_provider = AssumeRoleProvider::builder(role_arn) + .session_name(session_name) + .build_from_provider(base_provider) + .await; + Ok(Arc::new(credential_provider)) + } + CredentialProviderMetadata::Chain(metadata_vec) => { + if metadata_vec.is_empty() { + return Err(object_store::Error::Generic { + store: "S3", + source: "Cannot create credential provider chain with empty providers" + .into(), + }); + } + let mut chained_provider = CredentialsProviderChain::first_try( + metadata_vec[0].name(), + Box::pin(metadata_vec[0].create_credential_provider()).await?, + ); + for metadata in metadata_vec[1..].iter() { + chained_provider = chained_provider.or_else( + metadata.name(), + Box::pin(metadata.create_credential_provider()).await?, + ); + } + Ok(Arc::new(chained_provider)) + } + } + } +} + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicI32, Ordering}; + + use super::*; + + /// Test configuration builder for easier setup Hadoop configurations + #[derive(Debug, Default)] + struct TestConfigBuilder { + configs: HashMap<String, String>, + } + + impl TestConfigBuilder { + fn new() -> Self { + Self::default() + } + + fn with_region(mut self, region: &str) -> Self { + self.configs + .insert("fs.s3a.endpoint.region".to_string(), region.to_string()); + self + } + + fn with_credential_provider(mut self, provider: &str) -> Self { + self.configs.insert( + "fs.s3a.aws.credentials.provider".to_string(), + provider.to_string(), + ); + self + } + + fn with_bucket_credential_provider(mut self, bucket: &str, provider: &str) -> Self { + self.configs.insert( + format!("fs.s3a.bucket.{}.aws.credentials.provider", bucket), + provider.to_string(), + ); + self + } + + fn with_access_key(mut self, key: &str) -> Self { + self.configs + .insert("fs.s3a.access.key".to_string(), key.to_string()); + self + } + + fn with_secret_key(mut self, key: &str) -> Self { 
+ self.configs + .insert("fs.s3a.secret.key".to_string(), key.to_string()); + self + } + + fn with_session_token(mut self, token: &str) -> Self { + self.configs + .insert("fs.s3a.session.token".to_string(), token.to_string()); + self + } + + fn with_bucket_access_key(mut self, bucket: &str, key: &str) -> Self { + self.configs.insert( + format!("fs.s3a.bucket.{}.access.key", bucket), + key.to_string(), + ); + self + } + + fn with_bucket_secret_key(mut self, bucket: &str, key: &str) -> Self { + self.configs.insert( + format!("fs.s3a.bucket.{}.secret.key", bucket), + key.to_string(), + ); + self + } + + fn with_bucket_session_token(mut self, bucket: &str, token: &str) -> Self { + self.configs.insert( + format!("fs.s3a.bucket.{}.session.token", bucket), + token.to_string(), + ); + self + } + + fn with_assume_role_arn(mut self, arn: &str) -> Self { + self.configs + .insert("fs.s3a.assumed.role.arn".to_string(), arn.to_string()); + self + } + + fn with_assume_role_session_name(mut self, name: &str) -> Self { + self.configs.insert( + "fs.s3a.assumed.role.session.name".to_string(), + name.to_string(), + ); + self + } + + fn with_assume_role_credentials_provider(mut self, provider: &str) -> Self { + self.configs.insert( + "fs.s3a.assumed.role.credentials.provider".to_string(), + provider.to_string(), + ); + self + } + + fn build(self) -> HashMap<String, String> { + self.configs + } + } + + #[test] + #[cfg_attr(miri, ignore)] // AWS credential providers and object_store call foreign functions + fn test_create_store() { + let url = Url::parse("s3a://test_bucket/comet/spark-warehouse/part-00000.snappy.parquet") + .unwrap(); + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_ANONYMOUS) + .with_region("us-east-1") + .build(); + let (_object_store, path) = create_store(&url, &configs, Duration::from_secs(300)).unwrap(); + assert_eq!( + path, + Path::from("/comet/spark-warehouse/part-00000.snappy.parquet") + ); + } + + #[test] + fn test_get_config_trimmed() { + let configs = TestConfigBuilder::new() + .with_access_key("test_key") + .with_secret_key(" \n test_secret_key\n \n") + .with_session_token(" \n test_session_token\n \n") + .with_bucket_access_key("test-bucket", "test_bucket_key") + .with_bucket_secret_key("test-bucket", " \n test_bucket_secret_key\n \n") + .with_bucket_session_token("test-bucket", " \n test_bucket_session_token\n \n") + .build(); + + // bucket-specific keys + let access_key = get_config_trimmed(&configs, "test-bucket", "access.key"); + assert_eq!(access_key, Some("test_bucket_key")); + let secret_key = get_config_trimmed(&configs, "test-bucket", "secret.key"); + assert_eq!(secret_key, Some("test_bucket_secret_key")); + let session_token = get_config_trimmed(&configs, "test-bucket", "session.token"); + assert_eq!(session_token, Some("test_bucket_session_token")); + + // global keys + let access_key = get_config_trimmed(&configs, "test-bucket-2", "access.key"); + assert_eq!(access_key, Some("test_key")); + let secret_key = get_config_trimmed(&configs, "test-bucket-2", "secret.key"); + assert_eq!(secret_key, Some("test_secret_key")); + let session_token = get_config_trimmed(&configs, "test-bucket-2", "session.token"); + assert_eq!(session_token, Some("test_session_token")); + } + + #[test] + fn test_parse_credential_provider_names() { + let credential_provider_names = parse_credential_provider_names(""); + assert!(credential_provider_names.is_empty()); + + let credential_provider_names = parse_credential_provider_names(HADOOP_ANONYMOUS); + 
assert_eq!(credential_provider_names, vec![HADOOP_ANONYMOUS]); + + let aws_credential_provider_names = format!( + "{},{},{}", + HADOOP_ANONYMOUS, AWS_ENVIRONMENT, AWS_ENVIRONMENT_V1 + ); + let credential_provider_names = + parse_credential_provider_names(&aws_credential_provider_names); + assert_eq!( + credential_provider_names, + vec![HADOOP_ANONYMOUS, AWS_ENVIRONMENT, AWS_ENVIRONMENT_V1] + ); + + let aws_credential_provider_names = format!( + " {}, {},, {},", + HADOOP_ANONYMOUS, AWS_ENVIRONMENT, AWS_ENVIRONMENT_V1 + ); + let credential_provider_names = + parse_credential_provider_names(&aws_credential_provider_names); + assert_eq!( + credential_provider_names, + vec![HADOOP_ANONYMOUS, AWS_ENVIRONMENT, AWS_ENVIRONMENT_V1] + ); + + let aws_credential_provider_names = format!( + "\n {},\n {},\n , \n {},\n", + HADOOP_ANONYMOUS, AWS_ENVIRONMENT, AWS_ENVIRONMENT_V1 + ); + let credential_provider_names = + parse_credential_provider_names(&aws_credential_provider_names); + assert_eq!( + credential_provider_names, + vec![HADOOP_ANONYMOUS, AWS_ENVIRONMENT, AWS_ENVIRONMENT_V1] + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_default_credential_provider() { + let configs0 = TestConfigBuilder::new().build(); + let configs1 = TestConfigBuilder::new() + .with_credential_provider("") + .build(); + let configs2 = TestConfigBuilder::new() + .with_credential_provider("\n ,") + .build(); + + for configs in [configs0, configs1, configs2] { + let result = + build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return a credential provider for default config" + ); + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::Default + ); + } + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_anonymous_credential_provider() { + for provider_name in [HADOOP_ANONYMOUS, AWS_ANONYMOUS, AWS_ANONYMOUS_V1] { + let configs = TestConfigBuilder::new() + .with_credential_provider(provider_name) + .build(); + + let result = + build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!(result.is_none(), "Anonymous provider should return None"); + } + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_mixed_anonymous_and_other_providers_error() { + let configs = TestConfigBuilder::new() + .with_credential_provider(&format!("{},{}", HADOOP_ANONYMOUS, AWS_ENVIRONMENT)) + .build(); + + let result = + build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)).await; + assert!( + result.is_err(), + "Should error when mixing anonymous with other providers" + ); + + if let Err(e) = result { + assert!(e + .to_string() + .contains("Anonymous credential provider cannot be mixed")); + } + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_simple_credential_provider() { + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_SIMPLE) + .with_access_key("test_access_key") + .with_secret_key("test_secret_key") + .with_session_token("test_session_token") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return a credential provider for simple credentials" + ); + + assert_eq!( 
+ result.unwrap().metadata(), + CredentialProviderMetadata::Static { + is_valid: true, + access_key: "test_access_key".to_string(), + secret_key: "test_secret_key".to_string(), + session_token: None + } + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_temporary_credential_provider() { + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_TEMPORARY) + .with_access_key("test_access_key") + .with_secret_key("test_secret_key") + .with_session_token("test_session_token") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return a credential provider for temporary credentials" + ); + + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::Static { + is_valid: true, + access_key: "test_access_key".to_string(), + secret_key: "test_secret_key".to_string(), + session_token: Some("test_session_token".to_string()) + } + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_missing_access_key() { + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_SIMPLE) + .with_secret_key("test_secret_key") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return an invalid credential provider when access key is missing" + ); + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::Static { + is_valid: false, + access_key: "".to_string(), + secret_key: "test_secret_key".to_string(), + session_token: None + } + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_missing_secret_key() { + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_SIMPLE) + .with_access_key("test_access_key") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return an invalid credential provider when secret key is missing" + ); + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::Static { + is_valid: false, + access_key: "test_access_key".to_string(), + secret_key: "".to_string(), + session_token: None + } + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_missing_session_token_for_temporary() { + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_TEMPORARY) + .with_access_key("test_access_key") + .with_secret_key("test_secret_key") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return an invalid credential provider when session token is missing" + ); + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::Static { + is_valid: false, + access_key: "test_access_key".to_string(), + secret_key: "test_secret_key".to_string(), + session_token: None + } + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_bucket_specific_configuration() { + let configs = TestConfigBuilder::new() + .with_bucket_credential_provider("specific-bucket", HADOOP_SIMPLE) + 
.with_bucket_access_key("specific-bucket", "bucket_access_key") + .with_bucket_secret_key("specific-bucket", "bucket_secret_key") + .build(); + + let result = + build_credential_provider(&configs, "specific-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return a credential provider for bucket-specific config" + ); + + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_SIMPLE) + .with_access_key("test_access_key") + .with_secret_key("test_secret_key") + .with_bucket_credential_provider("specific-bucket", HADOOP_TEMPORARY) + .with_bucket_access_key("specific-bucket", "bucket_access_key") + .with_bucket_secret_key("specific-bucket", "bucket_secret_key") + .with_bucket_session_token("specific-bucket", "bucket_session_token") + .with_bucket_credential_provider("specific-bucket-2", HADOOP_TEMPORARY) + .with_bucket_access_key("specific-bucket-2", "bucket_access_key_2") + .with_bucket_secret_key("specific-bucket-2", "bucket_secret_key_2") + .with_bucket_session_token("specific-bucket-2", "bucket_session_token_2") + .build(); + + let result = + build_credential_provider(&configs, "specific-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return a credential provider for bucket-specific config" + ); + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::Static { + is_valid: true, + access_key: "bucket_access_key".to_string(), + secret_key: "bucket_secret_key".to_string(), + session_token: Some("bucket_session_token".to_string()) + } + ); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return a credential provider for default config" + ); + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::Static { + is_valid: true, + access_key: "test_access_key".to_string(), + secret_key: "test_secret_key".to_string(), + session_token: None + } + ); + + let result = + build_credential_provider(&configs, "specific-bucket-2", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return a credential provider for bucket-specific config" + ); + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::Static { + is_valid: true, + access_key: "bucket_access_key_2".to_string(), + secret_key: "bucket_secret_key_2".to_string(), + session_token: Some("bucket_session_token_2".to_string()) + } + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_assume_role_credential_provider() { + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_ASSUMED_ROLE) + .with_assume_role_arn("arn:aws:iam::123456789012:role/test-role") + .with_assume_role_session_name("test-session") + .with_access_key("base_access_key") + .with_secret_key("base_secret_key") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return a credential provider for assume role" + ); + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::AssumeRole { + role_arn: "arn:aws:iam::123456789012:role/test-role".to_string(), + session_name: "test-session".to_string(), + base_provider_metadata: Box::new(CredentialProviderMetadata::Chain(vec![ + CredentialProviderMetadata::Static { + is_valid: true, + access_key: "base_access_key".to_string(), 
+ secret_key: "base_secret_key".to_string(), + session_token: None + }, + CredentialProviderMetadata::Environment, + ])) + } + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_assume_role_missing_arn_error() { + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_ASSUMED_ROLE) + .with_access_key("base_access_key") + .with_secret_key("base_secret_key") + .build(); + + let result = + build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)).await; + assert!( + result.is_err(), + "Should error when assume role ARN is missing" + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_unsupported_credential_provider_error() { + let configs = TestConfigBuilder::new() + .with_credential_provider("unsupported.provider.Class") + .build(); + + let result = + build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)).await; + assert!( + result.is_err(), + "Should error for unsupported credential provider" + ); + + if let Err(e) = result { + assert!(e.to_string().contains("Unsupported credential provider")); + } + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_environment_credential_provider() { + for provider_name in [AWS_ENVIRONMENT, AWS_ENVIRONMENT_V1] { + let configs = TestConfigBuilder::new() + .with_credential_provider(provider_name) + .build(); + + let result = + build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!(result.is_some(), "Should return a credential provider"); + + let test_provider = result.unwrap().metadata(); + assert_eq!(test_provider, CredentialProviderMetadata::Environment); + } + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_ecs_credential_provider() { + for provider_name in [ + AWS_CONTAINER_CREDENTIALS, + AWS_CONTAINER_CREDENTIALS_V1, + AWS_EC2_CONTAINER_CREDENTIALS, + ] { + let configs = TestConfigBuilder::new() + .with_credential_provider(provider_name) + .build(); + + let result = + build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!(result.is_some(), "Should return a credential provider"); + + let test_provider = result.unwrap().metadata(); + assert_eq!(test_provider, CredentialProviderMetadata::Ecs); + } + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_imds_credential_provider() { + for provider_name in [AWS_INSTANCE_PROFILE, AWS_INSTANCE_PROFILE_V1] { + let configs = TestConfigBuilder::new() + .with_credential_provider(provider_name) + .build(); + + let result = + build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!(result.is_some(), "Should return a credential provider"); + + let test_provider = result.unwrap().metadata(); + assert_eq!(test_provider, CredentialProviderMetadata::Imds); + } + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_web_identity_credential_provider() { + for provider_name in [AWS_WEB_IDENTITY, AWS_WEB_IDENTITY_V1] { + let configs = TestConfigBuilder::new() + .with_credential_provider(provider_name) + .build(); + + let result = + build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + 
.unwrap(); + assert!(result.is_some(), "Should return a credential provider"); + + let test_provider = result.unwrap().metadata(); + assert_eq!(test_provider, CredentialProviderMetadata::WebIdentity); + } + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_hadoop_iam_instance_credential_provider() { + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_IAM_INSTANCE) + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!(result.is_some(), "Should return a credential provider"); + + let test_provider = result.unwrap().metadata(); + assert_eq!( + test_provider, + CredentialProviderMetadata::Chain(vec![ + CredentialProviderMetadata::Ecs, + CredentialProviderMetadata::Imds + ]) + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_chained_credential_providers() { + // Test three providers in chain: Environment -> IMDS -> ECS + let configs = TestConfigBuilder::new() + .with_credential_provider(&format!( + "{},{},{}", + AWS_ENVIRONMENT, AWS_INSTANCE_PROFILE, AWS_CONTAINER_CREDENTIALS + )) + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return a credential provider for complex chain" + ); + + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::Chain(vec![ + CredentialProviderMetadata::Environment, + CredentialProviderMetadata::Imds, + CredentialProviderMetadata::Ecs + ]) + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_static_environment_web_identity_chain() { + // Test chaining static credentials -> environment -> web identity + let configs = TestConfigBuilder::new() + .with_credential_provider(&format!( + "{},{},{}", + HADOOP_SIMPLE, AWS_ENVIRONMENT, AWS_WEB_IDENTITY + )) + .with_access_key("chain_access_key") + .with_secret_key("chain_secret_key") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return a credential provider for static+env+web chain" + ); + + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::Chain(vec![ + CredentialProviderMetadata::Static { + is_valid: true, + access_key: "chain_access_key".to_string(), + secret_key: "chain_secret_key".to_string(), + session_token: None + }, + CredentialProviderMetadata::Environment, + CredentialProviderMetadata::WebIdentity + ]) + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_assume_role_with_static_base_provider() { + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_ASSUMED_ROLE) + .with_assume_role_arn("arn:aws:iam::123456789012:role/test-role") + .with_assume_role_session_name("static-base-session") + .with_assume_role_credentials_provider(HADOOP_TEMPORARY) + .with_access_key("base_static_access_key") + .with_secret_key("base_static_secret_key") + .with_session_token("base_static_session_token") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return assume role provider with static base" + ); + + assert_eq!( + result.unwrap().metadata(), 
+ CredentialProviderMetadata::AssumeRole { + role_arn: "arn:aws:iam::123456789012:role/test-role".to_string(), + session_name: "static-base-session".to_string(), + base_provider_metadata: Box::new(CredentialProviderMetadata::Static { + is_valid: true, + access_key: "base_static_access_key".to_string(), + secret_key: "base_static_secret_key".to_string(), + session_token: Some("base_static_session_token".to_string()) + }) + } + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_assume_role_with_web_identity_base_provider() { + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_ASSUMED_ROLE) + .with_assume_role_arn("arn:aws:iam::123456789012:role/web-identity-role") + .with_assume_role_session_name("web-identity-session") + .with_assume_role_credentials_provider(AWS_WEB_IDENTITY) + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return assume role provider with web identity base" + ); + + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::AssumeRole { + role_arn: "arn:aws:iam::123456789012:role/web-identity-role".to_string(), + session_name: "web-identity-session".to_string(), + base_provider_metadata: Box::new(CredentialProviderMetadata::WebIdentity) + } + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_assume_role_with_chained_base_providers() { + // Test assume role with multiple base providers: Static -> Environment -> IMDS + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_ASSUMED_ROLE) + .with_assume_role_arn("arn:aws:iam::123456789012:role/chained-role") + .with_assume_role_session_name("chained-base-session") + .with_assume_role_credentials_provider(&format!( + "{},{},{}", + HADOOP_SIMPLE, AWS_ENVIRONMENT, AWS_INSTANCE_PROFILE + )) + .with_access_key("chained_base_access_key") + .with_secret_key("chained_base_secret_key") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return assume role provider with chained base" + ); + + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::AssumeRole { + role_arn: "arn:aws:iam::123456789012:role/chained-role".to_string(), + session_name: "chained-base-session".to_string(), + base_provider_metadata: Box::new(CredentialProviderMetadata::Chain(vec![ + CredentialProviderMetadata::Static { + is_valid: true, + access_key: "chained_base_access_key".to_string(), + secret_key: "chained_base_secret_key".to_string(), + session_token: None + }, + CredentialProviderMetadata::Environment, + CredentialProviderMetadata::Imds + ])) + } + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_assume_role_chained_with_other_providers() { + // Test assume role as first provider in a chain, followed by environment and IMDS + let configs = TestConfigBuilder::new() + .with_credential_provider(&format!( + " {}\n, {}\n", + HADOOP_ASSUMED_ROLE, AWS_INSTANCE_PROFILE + )) + .with_assume_role_arn("arn:aws:iam::123456789012:role/first-in-chain") + .with_assume_role_session_name("first-chain-session") + .with_assume_role_credentials_provider(&format!( + " {}\n, {}\n, {}\n", + AWS_WEB_IDENTITY, HADOOP_TEMPORARY, AWS_ENVIRONMENT + )) + 
.with_access_key("assume_role_base_key") + .with_secret_key("assume_role_base_secret") + .with_session_token("assume_role_base_token") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!( + result.is_some(), + "Should return chained provider with assume role first" + ); + + assert_eq!( + result.unwrap().metadata(), + CredentialProviderMetadata::Chain(vec![ + CredentialProviderMetadata::AssumeRole { + role_arn: "arn:aws:iam::123456789012:role/first-in-chain".to_string(), + session_name: "first-chain-session".to_string(), + base_provider_metadata: Box::new(CredentialProviderMetadata::Chain(vec![ + CredentialProviderMetadata::WebIdentity, + CredentialProviderMetadata::Static { + is_valid: true, + access_key: "assume_role_base_key".to_string(), + secret_key: "assume_role_base_secret".to_string(), + session_token: Some("assume_role_base_token".to_string()) + }, + CredentialProviderMetadata::Environment, + ])) + }, + CredentialProviderMetadata::Imds + ]) + ); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_assume_role_with_anonymous_base_provider_error() { + // Test that assume role with anonymous base provider fails + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_ASSUMED_ROLE) + .with_assume_role_arn("arn:aws:iam::123456789012:role/should-fail") + .with_assume_role_session_name("should-fail-session") + .with_assume_role_credentials_provider(HADOOP_ANONYMOUS) + .build(); + + let result = + build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)).await; + assert!( + result.is_err(), + "Should error when assume role uses anonymous base provider" + ); + + if let Err(e) = result { + assert!(e.to_string().contains( + "Anonymous credential provider cannot be used as assumed role credential provider" + )); + } + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_get_credential_from_static_credential_provider() { + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_SIMPLE) + .with_access_key("test_access_key") + .with_secret_key("test_secret_key") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!(result.is_some(), "Should return a credential provider"); + + let test_provider = result.unwrap(); + let credential = test_provider.get_credential().await.unwrap(); + assert_eq!(credential.key_id, "test_access_key"); + assert_eq!(credential.secret_key, "test_secret_key"); + assert_eq!(credential.token, None); + + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_TEMPORARY) + .with_access_key("test_access_key_2") + .with_secret_key("test_secret_key_2") + .with_session_token("test_session_token_2") + .build(); + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!(result.is_some(), "Should return a credential provider"); + + let test_provider = result.unwrap(); + let credential = test_provider.get_credential().await.unwrap(); + assert_eq!(credential.key_id, "test_access_key_2"); + assert_eq!(credential.secret_key, "test_secret_key_2"); + assert_eq!(credential.token, Some("test_session_token_2".to_string())); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn 
test_get_credential_from_invalid_static_credential_provider() { + let configs = TestConfigBuilder::new() + .with_credential_provider(HADOOP_SIMPLE) + .with_access_key("test_access_key") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!(result.is_some(), "Should return a credential provider"); + + let test_provider = result.unwrap(); + let result = test_provider.get_credential().await; + assert!(result.is_err(), "Should return an error when getting credential from invalid static credential provider"); + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_invalid_static_credential_provider_should_not_prevent_other_providers_from_working( + ) { + let configs = TestConfigBuilder::new() + .with_credential_provider(&format!("{},{}", HADOOP_TEMPORARY, HADOOP_SIMPLE)) + .with_access_key("test_access_key") + .with_secret_key("test_secret_key") + .build(); + + let result = build_credential_provider(&configs, "test-bucket", Duration::from_secs(300)) + .await + .unwrap(); + assert!(result.is_some(), "Should return a credential provider"); + + assert_eq!( + result.as_ref().unwrap().metadata(), + CredentialProviderMetadata::Chain(vec![ + CredentialProviderMetadata::Static { + is_valid: false, + access_key: "test_access_key".to_string(), + secret_key: "test_secret_key".to_string(), + session_token: None, + }, + CredentialProviderMetadata::Static { + is_valid: true, + access_key: "test_access_key".to_string(), + secret_key: "test_secret_key".to_string(), + session_token: None, + } + ]) + ); + + let test_provider = result.unwrap(); + + for _ in 0..10 { + let credential = test_provider.get_credential().await.unwrap(); + assert_eq!(credential.key_id, "test_access_key"); + assert_eq!(credential.secret_key, "test_secret_key"); + } + } + + #[derive(Debug)] + struct MockAwsCredentialProvider { + counter: AtomicI32, + } + + impl ProvideCredentials for MockAwsCredentialProvider { + fn provide_credentials<'a>( + &'a self, + ) -> aws_credential_types::provider::future::ProvideCredentials<'a> + where + Self: 'a, + { + let cnt = self.counter.fetch_add(1, Ordering::SeqCst); + let cred = Credentials::builder() + .access_key_id(format!("test_access_key_{}", cnt)) + .secret_access_key(format!("test_secret_key_{}", cnt)) + .expiry(SystemTime::now() + Duration::from_secs(60)) + .provider_name("mock_provider") + .build(); + aws_credential_types::provider::future::ProvideCredentials::ready(Ok(cred)) + } + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_cached_credential_provider_refresh_credential() { + let provider = Arc::new(MockAwsCredentialProvider { + counter: AtomicI32::new(0), + }); + + // 60 seconds before expiry, the credential is always refreshed + let cached_provider = CachedAwsCredentialProvider::new( + provider, + CredentialProviderMetadata::Default, + Duration::from_secs(60), + ); + for k in 0..3 { + let credential = cached_provider.get_credential().await.unwrap(); + assert_eq!(credential.key_id, format!("test_access_key_{}", k)); + assert_eq!(credential.secret_key, format!("test_secret_key_{}", k)); + } + } + + #[tokio::test] + #[cfg_attr(miri, ignore)] // AWS credential providers call foreign functions + async fn test_cached_credential_provider_cache_credential() { + let provider = Arc::new(MockAwsCredentialProvider { + counter: AtomicI32::new(0), + }); + + // 10 seconds before expiry, the 
credential is not refreshed + let cached_provider = CachedAwsCredentialProvider::new( + provider, + CredentialProviderMetadata::Default, + Duration::from_secs(10), + ); + for _ in 0..3 { + let credential = cached_provider.get_credential().await.unwrap(); + assert_eq!(credential.key_id, "test_access_key_0"); + assert_eq!(credential.secret_key, "test_secret_key_0"); + } + } + + #[test] + fn test_extract_s3_config_options() { + let mut configs = HashMap::new(); + configs.insert( + "fs.s3a.endpoint.region".to_string(), + "ap-northeast-1".to_string(), + ); + configs.insert( + "fs.s3a.requester.pays.enabled".to_string(), + "true".to_string(), + ); + let s3_configs = extract_s3_config_options(&configs, "test-bucket"); + assert_eq!( + s3_configs.get(&AmazonS3ConfigKey::Region), + Some(&"ap-northeast-1".to_string()) + ); + assert_eq!( + s3_configs.get(&AmazonS3ConfigKey::RequestPayer), + Some(&"true".to_string()) + ); + } + + #[test] + fn test_extract_s3_config_custom_endpoint() { + let cases = vec![ + ("custom.endpoint.com", "https://custom.endpoint.com"), + ("https://custom.endpoint.com", "https://custom.endpoint.com"), + ( + "https://custom.endpoint.com/path/to/resource", + "https://custom.endpoint.com/path/to/resource", + ), + ]; + for (endpoint, configured_endpoint) in cases { + let mut configs = HashMap::new(); + configs.insert("fs.s3a.endpoint".to_string(), endpoint.to_string()); + let s3_configs = extract_s3_config_options(&configs, "test-bucket"); + assert_eq!( + s3_configs.get(&AmazonS3ConfigKey::Endpoint), + Some(&configured_endpoint.to_string()) + ); + } + } + + #[test] + fn test_extract_s3_config_custom_endpoint_with_virtual_hosted_style() { + let cases = vec![ + ( + "custom.endpoint.com", + "https://custom.endpoint.com/test-bucket", + ), + ( + "https://custom.endpoint.com", + "https://custom.endpoint.com/test-bucket", + ), + ( + "https://custom.endpoint.com/", + "https://custom.endpoint.com/test-bucket", + ), + ( + "https://custom.endpoint.com/path/to/resource", + "https://custom.endpoint.com/path/to/resource/test-bucket", + ), + ( + "https://custom.endpoint.com/path/to/resource/", + "https://custom.endpoint.com/path/to/resource/test-bucket", + ), + ]; + for (endpoint, configured_endpoint) in cases { + let mut configs = HashMap::new(); + configs.insert("fs.s3a.endpoint".to_string(), endpoint.to_string()); + configs.insert("fs.s3a.path.style.access".to_string(), "true".to_string()); + let s3_configs = extract_s3_config_options(&configs, "test-bucket"); + assert_eq!( + s3_configs.get(&AmazonS3ConfigKey::Endpoint), + Some(&configured_endpoint.to_string()) + ); + } + } + + #[test] + fn test_extract_s3_config_ignore_default_endpoint() { + let mut configs = HashMap::new(); + configs.insert( + "fs.s3a.endpoint".to_string(), + "s3.amazonaws.com".to_string(), + ); + let s3_configs = extract_s3_config_options(&configs, "test-bucket"); + assert!(s3_configs.is_empty()); + + configs.insert("fs.s3a.endpoint".to_string(), "".to_string()); + let s3_configs = extract_s3_config_options(&configs, "test-bucket"); + assert!(s3_configs.is_empty()); + } + + #[test] + fn test_credential_provider_metadata_simple_string() { + // Test Static provider + let static_metadata = CredentialProviderMetadata::Static { + is_valid: true, + access_key: "sensitive_key".to_string(), + secret_key: "sensitive_secret".to_string(), + session_token: Some("sensitive_token".to_string()), + }; + assert_eq!(static_metadata.simple_string(), "Static(valid: true)"); + + // Test AssumeRole provider + let assume_role_metadata = 
CredentialProviderMetadata::AssumeRole { + role_arn: "arn:aws:iam::123456789012:role/test-role".to_string(), + session_name: "test-session".to_string(), + base_provider_metadata: Box::new(CredentialProviderMetadata::Environment), + }; + assert_eq!( + assume_role_metadata.simple_string(), + "AssumeRole(role: arn:aws:iam::123456789012:role/test-role, session: test-session, base: Environment)" + ); + + // Test Chain provider + let chain_metadata = CredentialProviderMetadata::Chain(vec![ + CredentialProviderMetadata::Static { + is_valid: false, + access_key: "key1".to_string(), + secret_key: "secret1".to_string(), + session_token: None, + }, + CredentialProviderMetadata::Environment, + CredentialProviderMetadata::Imds, + ]); + assert_eq!( + chain_metadata.simple_string(), + "Chain(Static(valid: false) -> Environment -> Imds)" + ); + + // Test nested AssumeRole with Chain base + let nested_metadata = CredentialProviderMetadata::AssumeRole { + role_arn: "arn:aws:iam::123456789012:role/nested-role".to_string(), + session_name: "nested-session".to_string(), + base_provider_metadata: Box::new(chain_metadata), + }; + assert_eq!( + nested_metadata.simple_string(), + "AssumeRole(role: arn:aws:iam::123456789012:role/nested-role, session: nested-session, base: Chain(Static(valid: false) -> Environment -> Imds))" + ); + } +} diff --git a/native/core/src/parquet/parquet_support.rs b/native/core/src/parquet/parquet_support.rs index 4067afaea..4e6f8a172 100644 --- a/native/core/src/parquet/parquet_support.rs +++ b/native/core/src/parquet/parquet_support.rs @@ -37,9 +37,12 @@ use datafusion_comet_spark_expr::EvalMode; use object_store::path::Path; use object_store::{parse_url, ObjectStore}; use std::collections::HashMap; +use std::time::Duration; use std::{fmt::Debug, hash::Hash, sync::Arc}; use url::Url; +use super::objectstore; + static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); static PARQUET_OPTIONS: CastOptions = CastOptions { @@ -343,6 +346,16 @@ fn parse_hdfs_url(_url: &Url) -> Result<(Box<dyn ObjectStore>, Path), object_sto pub(crate) fn prepare_object_store( runtime_env: Arc<RuntimeEnv>, url: String, +) -> Result<(ObjectStoreUrl, Path), ExecutionError> { + prepare_object_store_with_configs(runtime_env, url, &HashMap::new()) +} + +/// Parses the url, registers the object store with configurations, and returns a tuple of the object store url +/// and object store path +pub(crate) fn prepare_object_store_with_configs( + runtime_env: Arc<RuntimeEnv>, + url: String, + object_store_configs: &HashMap<String, String>, ) -> Result<(ObjectStoreUrl, Path), ExecutionError> { let mut url = Url::parse(url.as_str()) .map_err(|e| ExecutionError::GeneralError(format!("Error parsing URL {url}: {e}")))?; @@ -361,6 +374,8 @@ pub(crate) fn prepare_object_store( let (object_store, object_store_path): (Box<dyn ObjectStore>, Path) = if scheme == "hdfs" { parse_hdfs_url(&url) + } else if scheme == "s3" { + objectstore::s3::create_store(&url, object_store_configs, Duration::from_secs(300)) } else { parse_url(&url) } @@ -386,19 +401,14 @@ mod tests { use crate::execution::operators::ExecutionError; let local_file_system_url = "file:///comet/spark-warehouse/part-00000.snappy.parquet"; - let s3_url = "s3a://test_bucket/comet/spark-warehouse/part-00000.snappy.parquet"; let hdfs_url = "hdfs://localhost:8020/comet/spark-warehouse/part-00000.snappy.parquet"; - let all_urls = [local_file_system_url, s3_url, hdfs_url]; + let all_urls = [local_file_system_url, hdfs_url]; let expected: Vec<Result<(ObjectStoreUrl, 
Path), ExecutionError>> = vec![ Ok(( ObjectStoreUrl::parse("file://").unwrap(), Path::from("/comet/spark-warehouse/part-00000.snappy.parquet"), )), - Ok(( - ObjectStoreUrl::parse("s3://test_bucket").unwrap(), - Path::from("/comet/spark-warehouse/part-00000.snappy.parquet"), - )), Err(ExecutionError::GeneralError( "Generic HadoopFileSystem error: Hdfs support is not enabled in this build" .parse() diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 9a41f977a..7ccce21a2 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -94,6 +94,13 @@ message NativeScan { repeated spark.spark_expression.Expr default_values = 10; repeated int64 default_values_indexes = 11; bool case_sensitive = 12; + // Options for configuring object stores such as AWS S3, GCS, etc. The key-value pairs are taken + // from Hadoop configuration for compatibility with Hadoop FileSystem implementations of object + // stores. + // The configuration values have hadoop. or spark.hadoop. prefix trimmed. For instance, the + // configuration value "spark.hadoop.fs.s3a.access.key" will be stored as "fs.s3a.access.key" in + // the map. + map<string, string> object_store_options = 13; } message Projection { diff --git a/pom.xml b/pom.xml index f1514c943..d189ed25b 100644 --- a/pom.xml +++ b/pom.xml @@ -84,6 +84,8 @@ under the License. <semanticdb.version>4.8.8</semanticdb.version> <slf4j.version>2.0.7</slf4j.version> <guava.version>33.2.1-jre</guava.version> + <testcontainers.version>1.21.0</testcontainers.version> + <amazon-awssdk-v2.version>2.31.51</amazon-awssdk-v2.version> <jni.dir>${project.basedir}/../native/target/debug</jni.dir> <platform>darwin</platform> <arch>x86_64</arch> @@ -453,6 +455,20 @@ under the License. </exclusions> </dependency> + <!-- TestContainers for testing reading Parquet on S3 --> + <dependency> + <groupId>org.testcontainers</groupId> + <artifactId>minio</artifactId> + <version>${testcontainers.version}</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>software.amazon.awssdk</groupId> + <artifactId>s3</artifactId> + <version>${amazon-awssdk-v2.version}</version> + <scope>test</scope> + </dependency> + <dependency> <groupId>org.codehaus.jackson</groupId> <artifactId>jackson-mapper-asl</artifactId> diff --git a/spark/pom.xml b/spark/pom.xml index 9a296af85..95ea6971b 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -165,6 +165,14 @@ under the License. </exclusion> </exclusions> </dependency> + <dependency> + <groupId>org.testcontainers</groupId> + <artifactId>minio</artifactId> + </dependency> + <dependency> + <groupId>software.amazon.awssdk</groupId> + <artifactId>s3</artifactId> + </dependency> </dependencies> <build> diff --git a/spark/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala b/spark/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala new file mode 100644 index 000000000..aebe31ed1 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/objectstore/NativeConfig.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.objectstore + +import java.net.URI +import java.util.Locale + +import org.apache.hadoop.conf.Configuration + +object NativeConfig { + + private val objectStoreConfigPrefixes = Map( + // Amazon S3 configurations + "s3" -> Seq("fs.s3a."), + "s3a" -> Seq("fs.s3a."), + // Google Cloud Storage configurations + "gs" -> Seq("fs.gs."), + // Azure Blob Storage configurations (can use both prefixes) + "wasb" -> Seq("fs.azure.", "fs.wasb."), + "wasbs" -> Seq("fs.azure.", "fs.wasb."), + // Azure Data Lake Storage Gen2 configurations + "abfs" -> Seq("fs.abfs."), + // Azure Data Lake Storage Gen2 secure configurations (can use both prefixes) + "abfss" -> Seq("fs.abfss.", "fs.abfs.")) + + /** + * Extract object store configurations from Hadoop configuration for native DataFusion usage. + * This includes S3, GCS, Azure and other cloud storage configurations. + * + * This method extracts all configurations with supported prefixes, automatically capturing both + * global configurations (e.g., fs.s3a.access.key) and per-bucket configurations (e.g., + * fs.s3a.bucket.{bucket-name}.access.key). The native code will prioritize per-bucket + * configurations over global ones when both are present. + * + * The configurations are passed to the native code which uses object_store's parse_url_opts for + * consistent and standardized cloud storage support across all providers. 
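 + *
 + * Illustrative usage (the configuration keys and values below are examples only):
 + * {{{
 + *   val conf = new Configuration()
 + *   conf.set("fs.s3a.access.key", "my-access-key")
 + *   conf.set("fs.s3a.bucket.my-bucket.endpoint.region", "eu-central-1")
 + *   // For an s3a URI, both the global and the per-bucket fs.s3a.* entries are returned
 + *   val opts = extractObjectStoreOptions(conf, new URI("s3a://my-bucket/data/file.parquet"))
 + * }}}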
+ */ + def extractObjectStoreOptions(hadoopConf: Configuration, uri: URI): Map[String, String] = { + val scheme = uri.getScheme.toLowerCase(Locale.ROOT) + + // Get prefixes for this scheme, return early if none found + val prefixes = objectStoreConfigPrefixes.get(scheme) + if (prefixes.isEmpty) { + return Map.empty[String, String] + } + + import scala.collection.JavaConverters._ + + // Extract all configurations that match the object store prefixes + val options = scala.collection.mutable.Map[String, String]() + hadoopConf.iterator().asScala.foreach { entry => + val key = entry.getKey + val value = entry.getValue + // Check if key starts with any of the prefixes for this scheme + if (prefixes.get.exists(prefix => key.startsWith(prefix))) { + options(key) = value + } + } + + options.toMap + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7267852b9..f7060d8a1 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD} +import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.execution.datasources.v2.{DataSourceRDD, DataSourceRDDPartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} @@ -51,6 +52,7 @@ import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.{isCometScan, withInfo} import org.apache.comet.expressions._ +import org.apache.comet.objectstore.NativeConfig import org.apache.comet.serde.ExprOuterClass.{AggExpr, DataType => ProtoDataType, Expr, ScalarFunc} import org.apache.comet.serde.ExprOuterClass.DataType._ import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, BuildSide, JoinType, Operator} @@ -2241,12 +2243,16 @@ object QueryPlanSerde extends Logging with CometExprShim { } // TODO: modify CometNativeScan to generate the file partitions without instantiating RDD. 
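+ // Remember the first file of the scan so that, further below, its URI can be passed to
+ // NativeConfig.extractObjectStoreOptions and the matching object store options attached to the plan.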
+ var firstPartition: Option[PartitionedFile] = None scan.inputRDD match { case rdd: DataSourceRDD => val partitions = rdd.partitions partitions.foreach(p => { val inputPartitions = p.asInstanceOf[DataSourceRDDPartition].inputPartitions inputPartitions.foreach(partition => { + if (firstPartition.isEmpty) { + firstPartition = partition.asInstanceOf[FilePartition].files.headOption + } partition2Proto( partition.asInstanceOf[FilePartition], nativeScanBuilder, @@ -2255,6 +2261,9 @@ object QueryPlanSerde extends Logging with CometExprShim { }) case rdd: FileScanRDD => rdd.filePartitions.foreach(partition => { + if (firstPartition.isEmpty) { + firstPartition = partition.files.headOption + } partition2Proto(partition, nativeScanBuilder, scan.relation.partitionSchema) }) case _ => @@ -2286,6 +2295,17 @@ object QueryPlanSerde extends Logging with CometExprShim { nativeScanBuilder.setSessionTimezone(conf.getConfString("spark.sql.session.timeZone")) nativeScanBuilder.setCaseSensitive(conf.getConf[Boolean](SQLConf.CASE_SENSITIVE)) + // Collect S3/cloud storage configurations + val hadoopConf = scan.relation.sparkSession.sessionState + .newHadoopConfWithOptions(scan.relation.options) + firstPartition.foreach { partitionFile => + val objectStoreOptions = + NativeConfig.extractObjectStoreOptions(hadoopConf, partitionFile.pathUri) + objectStoreOptions.foreach { case (key, value) => + nativeScanBuilder.putObjectStoreOptions(key, value) + } + } + Some(result.setNativeScan(nativeScanBuilder).build()) } else { diff --git a/spark/src/test/scala/org/apache/comet/objectstore/NativeConfigSuite.scala b/spark/src/test/scala/org/apache/comet/objectstore/NativeConfigSuite.scala new file mode 100644 index 000000000..7ba40192f --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/objectstore/NativeConfigSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.objectstore + +import java.net.URI + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers + +import org.apache.hadoop.conf.Configuration + +class NativeConfigSuite extends AnyFunSuite with Matchers { + + test("extractObjectStoreOptions - multiple cloud provider configurations") { + val hadoopConf = new Configuration() + // S3A configs + hadoopConf.set("fs.s3a.access.key", "s3-access-key") + hadoopConf.set("fs.s3a.secret.key", "s3-secret-key") + hadoopConf.set("fs.s3a.endpoint.region", "us-east-1") + hadoopConf.set("fs.s3a.bucket.special-bucket.access.key", "special-access-key") + hadoopConf.set("fs.s3a.bucket.special-bucket.endpoint.region", "eu-central-1") + + // GCS configs + hadoopConf.set("fs.gs.project.id", "gcp-project") + + // Azure configs + hadoopConf.set("fs.azure.account.key.testaccount.blob.core.windows.net", "azure-key") + + // Should extract s3 options + Seq("s3a://test-bucket/test-object", "s3://test-bucket/test-object").foreach { path => + val options = NativeConfig.extractObjectStoreOptions(hadoopConf, new URI(path)) + assert(options("fs.s3a.access.key") == "s3-access-key") + assert(options("fs.s3a.secret.key") == "s3-secret-key") + assert(options("fs.s3a.endpoint.region") == "us-east-1") + assert(options("fs.s3a.bucket.special-bucket.access.key") == "special-access-key") + assert(options("fs.s3a.bucket.special-bucket.endpoint.region") == "eu-central-1") + assert(!options.contains("fs.gs.project.id")) + } + val gsOptions = + NativeConfig.extractObjectStoreOptions(hadoopConf, new URI("gs://test-bucket/test-object")) + assert(gsOptions("fs.gs.project.id") == "gcp-project") + assert(!gsOptions.contains("fs.s3a.access.key")) + + val azureOptions = NativeConfig.extractObjectStoreOptions( + hadoopConf, + new URI("wasb://test-bucket/test-object")) + assert(azureOptions("fs.azure.account.key.testaccount.blob.core.windows.net") == "azure-key") + assert(!azureOptions.contains("fs.s3a.access.key")) + + // Unsupported scheme should return empty options + val unsupportedOptions = NativeConfig.extractObjectStoreOptions( + hadoopConf, + new URI("unsupported://test-bucket/test-object")) + assert(unsupportedOptions.isEmpty, "Unsupported scheme should return empty options") + } +} diff --git a/spark/src/test/scala/org/apache/comet/parquet/ParquetReadFromS3Suite.scala b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadFromS3Suite.scala new file mode 100644 index 000000000..ff5a78243 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/parquet/ParquetReadFromS3Suite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.parquet + +import java.net.URI + +import scala.util.Try + +import org.testcontainers.containers.MinIOContainer +import org.testcontainers.utility.DockerImageName + +import org.apache.spark.SparkConf +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.comet.CometNativeScanExec +import org.apache.spark.sql.comet.CometScanExec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.functions.{col, sum} + +import org.apache.comet.CometConf.SCAN_NATIVE_ICEBERG_COMPAT + +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider +import software.amazon.awssdk.services.s3.S3Client +import software.amazon.awssdk.services.s3.model.CreateBucketRequest +import software.amazon.awssdk.services.s3.model.HeadBucketRequest + +class ParquetReadFromS3Suite extends CometTestBase with AdaptiveSparkPlanHelper { + + private var minioContainer: MinIOContainer = _ + private val userName = "minio-test-user" + private val password = "minio-test-password" + private val testBucketName = "test-bucket" + + override def beforeAll(): Unit = { + // Start MinIO container + minioContainer = new MinIOContainer(DockerImageName.parse("minio/minio:latest")) + .withUserName(userName) + .withPassword(password) + minioContainer.start() + createBucketIfNotExists(testBucketName) + + // Initialize Spark session + super.beforeAll() + } + + override def afterAll(): Unit = { + super.afterAll() + if (minioContainer != null) { + minioContainer.stop() + } + } + + override protected def sparkConf: SparkConf = { + val conf = super.sparkConf + conf.set("spark.hadoop.fs.s3a.access.key", userName) + conf.set("spark.hadoop.fs.s3a.secret.key", password) + conf.set("spark.hadoop.fs.s3a.endpoint", minioContainer.getS3URL) + conf.set("spark.hadoop.fs.s3a.path.style.access", "true") + } + + private def createBucketIfNotExists(bucketName: String): Unit = { + val credentials = AwsBasicCredentials.create(userName, password) + val s3Client = S3Client + .builder() + .endpointOverride(URI.create(minioContainer.getS3URL)) + .credentialsProvider(StaticCredentialsProvider.create(credentials)) + .forcePathStyle(true) + .build() + try { + val bucketExists = Try { + s3Client.headBucket(HeadBucketRequest.builder().bucket(bucketName).build()) + true + }.getOrElse(false) + + if (!bucketExists) { + val request = CreateBucketRequest.builder().bucket(bucketName).build() + s3Client.createBucket(request) + } + } finally { + s3Client.close() + } + } + + private def writeTestParquetFile(filePath: String): Unit = { + val df = spark.range(0, 1000) + df.write.format("parquet").mode(SaveMode.Overwrite).save(filePath) + } + + // native_iceberg_compat mode does not have comprehensive S3 support, so we don't run tests + // under this mode. 
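+ // The environment variable is checked at suite construction time, so in that mode the test below is not registered at all.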
+ if (sys.env.getOrElse("COMET_PARQUET_SCAN_IMPL", "") != SCAN_NATIVE_ICEBERG_COMPAT) { + test("read parquet file from MinIO") { + val testFilePath = s"s3a://$testBucketName/data/test-file.parquet" + writeTestParquetFile(testFilePath) + + val df = spark.read.format("parquet").load(testFilePath).agg(sum(col("id"))) + val scans = collect(df.queryExecution.executedPlan) { + case p: CometScanExec => + p + case p: CometNativeScanExec => + p + } + assert(scans.size == 1) + + assert(df.first().getLong(0) == 499500) + } + } +}