This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ray.git


The following commit(s) were added to refs/heads/main by this push:
     new 42681a1  Add object store support (#78)
42681a1 is described below

commit 42681a1cef90486f1172bb5167ad2107ee337293
Author: robtandy <[email protected]>
AuthorDate: Sun Mar 9 14:56:57 2025 -0400

    Add object store support (#78)
---
 Cargo.toml               |   2 +-
 README.md                |  12 ++++-
 datafusion_ray/core.py   |  93 +++++++++++++++++++++++-------------
 examples/http_csv.py     |  40 ++++++++++++++++
 examples/tips.py         |  23 ++-------
 src/context.rs           |  47 +++++++++++-------
 src/processor_service.rs |  26 +++++-----
 src/util.rs              | 122 +++++++++++++++++++++++++++++++++++++++++++++--
 tpch/tpcbench.py         |  13 ++---
 9 files changed, 282 insertions(+), 96 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index fa4dc89..9d9b659 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -52,6 +52,7 @@ object_store = { version = "0.11.0", features = [
 ] }
 parking_lot = { version = "0.12", features = ["deadlock_detection"] }
 prost = "0.13"
+protobuf-src = "2.1"
 pyo3 = { version = "0.23", features = [
   "extension-module",
   "abi3",
@@ -85,7 +86,6 @@ tonic-build = { version = "0.8", default-features = false, 
features = [
   "prost",
 ] }
 url = "2"
-protobuf-src = "2.1"
 
 [dev-dependencies]
 tempfile = "3.17"
diff --git a/README.md b/README.md
index 881a579..540f44d 100644
--- a/README.md
+++ b/README.md
@@ -60,12 +60,20 @@ Once installed, you can run queries using DataFusion's 
familiar API while levera
 capabilities of Ray.
 
 ```python
+# from example in ./examples/http_csv.py
 import ray
 from datafusion_ray import DFRayContext, df_ray_runtime_env
 
 ray.init(runtime_env=df_ray_runtime_env)
-session = DFRayContext()
-df = session.sql("SELECT * FROM my_table WHERE value > 100")
+
+ctx = DFRayContext()
+ctx.register_csv(
+    "aggregate_test_100",
+    "https://github.com/apache/arrow-testing/raw/master/data/csv/aggregate_test_100.csv",
+)
+
+df = ctx.sql("SELECT c1,c2,c3 FROM aggregate_test_100 LIMIT 5")
+
 df.show()
 ```
 
diff --git a/datafusion_ray/core.py b/datafusion_ray/core.py
index 70955bc..0d1736a 100644
--- a/datafusion_ray/core.py
+++ b/datafusion_ray/core.py
@@ -86,9 +86,7 @@ async def wait_for(coros, name=""):
     # wrap the coro in a task to work with python 3.10 and 3.11+ where 
asyncio.wait semantics
     # changed to not accept any awaitable
     start = time.time()
-    done, _ = await asyncio.wait(
-        [asyncio.create_task(_ensure_coro(c)) for c in coros]
-    )
+    done, _ = await asyncio.wait([asyncio.create_task(_ensure_coro(c)) for c 
in coros])
     end = time.time()
     log.info(f"waiting for {name} took {end - start}s")
     for d in done:
@@ -166,9 +164,7 @@ class DFRayProcessorPool:
         need_to_make = need - have
 
         if need_to_make > can_make:
-            raise Exception(
-                f"Cannot allocate workers above {self.max_workers}"
-            )
+            raise Exception(f"Cannot allocate workers above 
{self.max_workers}")
 
         if need_to_make > 0:
             log.debug(f"creating {need_to_make} additional processors")
@@ -197,9 +193,9 @@ class DFRayProcessorPool:
         self.processors_ready.clear()
         processor_key = new_friendly_name()
         log.debug(f"starting processor: {processor_key}")
-        processor = DFRayProcessor.options(
-            name=f"Processor : {processor_key}"
-        ).remote(processor_key)
+        processor = DFRayProcessor.options(name=f"Processor : 
{processor_key}").remote(
+            processor_key
+        )
         self.pool[processor_key] = processor
         self.processors_started.add(processor.start_up.remote())
         self.available.add(processor_key)
@@ -248,9 +244,7 @@ class DFRayProcessorPool:
 
     async def all_done(self):
         log.info("calling processor all done")
-        refs = [
-            processor.all_done.remote() for processor in self.pool.values()
-        ]
+        refs = [processor.all_done.remote() for processor in 
self.pool.values()]
         await wait_for(refs, "processors to be all done")
         log.info("all processors shutdown")
 
@@ -293,9 +287,7 @@ class DFRayProcessor:
         )
 
     async def serve(self):
-        log.info(
-            f"[{self.processor_key}] serving on 
{self.processor_service.addr()}"
-        )
+        log.info(f"[{self.processor_key}] serving on 
{self.processor_service.addr()}")
         await self.processor_service.serve()
         log.info(f"[{self.processor_key}] done serving")
 
@@ -332,9 +324,7 @@ class DFRayContextSupervisor:
         worker_pool_min: int,
         worker_pool_max: int,
     ) -> None:
-        log.info(
-            f"Creating DFRayContextSupervisor worker_pool_min: 
{worker_pool_min}"
-        )
+        log.info(f"Creating DFRayContextSupervisor worker_pool_min: 
{worker_pool_min}")
         self.pool = DFRayProcessorPool(worker_pool_min, worker_pool_max)
         self.stages: dict[str, InternalStageData] = {}
         log.info("Created DFRayContextSupervisor")
@@ -347,9 +337,7 @@ class DFRayContextSupervisor:
 
     async def get_stage_addrs(self, stage_id: int):
         addrs = [
-            sd.remote_addr
-            for sd in self.stages.values()
-            if sd.stage_id == stage_id
+            sd.remote_addr for sd in self.stages.values() if sd.stage_id == 
stage_id
         ]
         return addrs
 
@@ -399,10 +387,7 @@ class DFRayContextSupervisor:
             refs.append(
                 isd.remote_processor.update_plan.remote(
                     isd.stage_id,
-                    {
-                        stage_id: val["child_addrs"]
-                        for (stage_id, val) in kid.items()
-                    },
+                    {stage_id: val["child_addrs"] for (stage_id, val) in 
kid.items()},
                     isd.partition_group,
                     isd.plan_bytes,
                 )
@@ -434,9 +419,7 @@ class DFRayContextSupervisor:
                 ]
 
                 # sanity check
-                assert all(
-                    [op == output_partitions[0] for op in output_partitions]
-                )
+                assert all([op == output_partitions[0] for op in 
output_partitions])
                 output_partitions = output_partitions[0]
 
                 for child_stage_isd in child_stage_datas:
@@ -520,9 +503,7 @@ class DFRayDataFrame:
             )
             log.debug(f"last stage addrs {last_stage_addrs}")
 
-            reader = self.df.read_final_stage(
-                last_stage_id, last_stage_addrs[0]
-            )
+            reader = self.df.read_final_stage(last_stage_id, 
last_stage_addrs[0])
             log.debug("got reader")
             self._batches = list(reader)
         return self._batches
@@ -589,11 +570,55 @@ class DFRayContext:
         )
 
     def register_parquet(self, name: str, path: str):
+        """
+        Register a Parquet file with the given name and path.
+
+        The path can be a relative or absolute local filesystem path, or a URL.
+        If the path is an object store URL (s3://, gs://, http(s)://), the
+        matching object store is registered before the table. Its configuration
+        comes from the environment; for example, s3:// stores are built with
+        ``AmazonS3Builder::from_env``, which reads the AWS_* environment variables.
+
+        Parameters:
+        name (str): The name to register the Parquet file under.
+        path (str): The local path or URL of the Parquet file.
+        """
         self.ctx.register_parquet(name, path)
 
-    def register_listing_table(
-        self, name: str, path: str, file_extention="parquet"
-    ):
+    def register_csv(self, name: str, path: str):
+        """
+        Register a CSV file with the given name and path.
+
+        The path can be a relative or absolute local filesystem path, or a URL.
+        If the path is an object store URL (s3://, gs://, http(s)://), the
+        matching object store is registered before the table. Its configuration
+        comes from the environment; for example, s3:// stores are built with
+        ``AmazonS3Builder::from_env``, which reads the AWS_* environment variables.
+
+        Parameters:
+        name (str): The name to register the CSV file under.
+        path (str): The local path or URL of the CSV file.
+        """
+        self.ctx.register_csv(name, path)
+
+    def register_listing_table(self, name: str, path: str, 
file_extention="parquet"):
+        """
+        Register a directory of Parquet files as a single table with the given name.
+
+        The path can be a relative or absolute local filesystem path, or a URL.
+        A trailing "/" is appended to it so that it is treated as a directory.
+        If the path is an object store URL (s3://, gs://, http(s)://), the
+        matching object store is registered before the table. Its configuration
+        comes from the environment; for example, s3:// stores are built with
+        ``AmazonS3Builder::from_env``, which reads the AWS_* environment variables.
+
+        Parameters:
+        name (str): The name to register the table under.
+        path (str): The local path or URL of the directory holding the Parquet files.
+        file_extention (str): Only files whose names end with this suffix are
+            read. Defaults to "parquet".
+        """
+
         self.ctx.register_listing_table(name, path, file_extention)
 
     def sql(self, query: str) -> DFRayDataFrame:
diff --git a/examples/http_csv.py b/examples/http_csv.py
new file mode 100644
index 0000000..9fc7de4
--- /dev/null
+++ b/examples/http_csv.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# this is a port of the example at
+# https://github.com/apache/datafusion/blob/45.0.0/datafusion-examples/examples/query-http-csv.rs
+
+import ray
+
+from datafusion_ray import DFRayContext, df_ray_runtime_env
+
+
+def main():
+    """Register a CSV file served over HTTPS and print its first five rows."""
+    ctx = DFRayContext()
+    ctx.register_csv(
+        "aggregate_test_100",
+        "https://github.com/apache/arrow-testing/raw/master/data/csv/aggregate_test_100.csv",
+    )
+
+    df = ctx.sql("SELECT c1,c2,c3 FROM aggregate_test_100 LIMIT 5")
+
+    df.show()
+
+
+if __name__ == "__main__":
+    # only start Ray and run the example when executed as a script, not on import
+    ray.init(namespace="http_csv", runtime_env=df_ray_runtime_env)
+    main()
diff --git a/examples/tips.py b/examples/tips.py
index 7d72ba5..7537f5f 100644
--- a/examples/tips.py
+++ b/examples/tips.py
@@ -16,40 +16,27 @@
 # under the License.
 
 import argparse
-import datafusion
+import os
 import ray
 
-from datafusion_ray import DFRayContext
+from datafusion_ray import DFRayContext, df_ray_runtime_env
 
 
 def go(data_dir: str):
     ctx = DFRayContext()
-    # we could set this value to however many CPUs we plan to give each
-    # ray task
-    ctx.set("datafusion.execution.target_partitions", "1")
-    ctx.set("datafusion.optimizer.enable_round_robin_repartition", "false")
 
-    ctx.register_parquet("tips", f"{data_dir}/tips*.parquet")
+    ctx.register_parquet("tips", os.path.join(data_dir, "tips.parquet"))
 
     df = ctx.sql(
         "select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by 
sex, smoker order by sex, smoker"
     )
     df.show()
 
-    print("no ray result:")
-
-    # compare to non ray version
-    ctx = datafusion.SessionContext()
-    ctx.register_parquet("tips", f"{data_dir}/tips*.parquet")
-    ctx.sql(
-        "select sex, smoker, avg(tip/total_bill) as tip_pct from tips group by 
sex, smoker order by sex, smoker"
-    ).show()
-
 
 if __name__ == "__main__":
-    ray.init(namespace="tips")
+    ray.init(namespace="tips", runtime_env=df_ray_runtime_env)
     parser = argparse.ArgumentParser()
-    parser.add_argument("--data-dir", required=True, help="path to 
tips*.parquet files")
+    parser.add_argument("--data-dir", required=True, help="path to 
tips.parquet files")
     args = parser.parse_args()
 
     go(args.data_dir)
diff --git a/src/context.rs b/src/context.rs
index 191d632..ee76080 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -16,17 +16,17 @@
 // under the License.
 
 use datafusion::datasource::file_format::parquet::ParquetFormat;
-use datafusion::datasource::listing::ListingOptions;
-use datafusion::{execution::SessionStateBuilder, prelude::*};
+use datafusion::datasource::listing::{ListingOptions, ListingTableUrl};
+use datafusion::execution::SessionStateBuilder;
+use datafusion::prelude::{CsvReadOptions, ParquetReadOptions, SessionConfig, 
SessionContext};
 use datafusion_python::utils::wait_for_future;
-use object_store::aws::AmazonS3Builder;
+use log::debug;
 use pyo3::prelude::*;
 use std::sync::Arc;
 
 use crate::dataframe::DFRayDataFrame;
 use crate::physical::RayStageOptimizerRule;
-use crate::util::ResultExt;
-use url::Url;
+use crate::util::{maybe_register_object_store, ResultExt};
 
 /// Internal Session Context object for the python class DFRayContext
 #[pyclass]
@@ -54,23 +54,27 @@ impl DFRayContext {
         Ok(Self { ctx })
     }
 
-    pub fn register_s3(&self, bucket_name: String) -> PyResult<()> {
-        let s3 = AmazonS3Builder::from_env()
-            .with_bucket_name(&bucket_name)
-            .build()
-            .to_py_err()?;
+    pub fn register_parquet(&self, py: Python, name: String, path: String) -> 
PyResult<()> {
+        let options = ParquetReadOptions::default();
+
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
 
-        let path = format!("s3://{bucket_name}");
-        let s3_url = Url::parse(&path).to_py_err()?;
-        let arc_s3 = Arc::new(s3);
-        self.ctx.register_object_store(&s3_url, arc_s3.clone());
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+        debug!("register_parquet: registering table {} at {}", name, path);
+
+        wait_for_future(py, self.ctx.register_parquet(&name, &path, 
options.clone()))?;
         Ok(())
     }
 
-    pub fn register_parquet(&self, py: Python, name: String, path: String) -> 
PyResult<()> {
-        let options = ParquetReadOptions::default();
+    pub fn register_csv(&self, py: Python, name: String, path: String) -> 
PyResult<()> {
+        let options = CsvReadOptions::default();
 
-        wait_for_future(py, self.ctx.register_parquet(&name, &path, 
options.clone()))?;
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
+
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+        debug!("register_csv: registering table {} at {}", name, path);
+
+        wait_for_future(py, self.ctx.register_csv(&name, &path, 
options.clone()))?;
         Ok(())
     }
 
@@ -85,6 +89,15 @@ impl DFRayContext {
         let options =
             
ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(file_extension);
 
+        let path = format!("{path}/");
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
+
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+
+        debug!(
+            "register_listing_table: registering table {} at {}",
+            name, path
+        );
         wait_for_future(
             py,
             self.ctx
diff --git a/src/processor_service.rs b/src/processor_service.rs
index 5164577..120ba21 100644
--- a/src/processor_service.rs
+++ b/src/processor_service.rs
@@ -15,15 +15,15 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::collections::hash_map::Entry;
 use std::collections::HashMap;
+use std::collections::hash_map::Entry;
 use std::error::Error;
 use std::sync::Arc;
 
 use arrow::array::RecordBatch;
+use arrow_flight::FlightClient;
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
-use arrow_flight::FlightClient;
 use datafusion::common::internal_datafusion_err;
 use datafusion::execution::SessionStateBuilder;
 use datafusion::physical_plan::ExecutionPlan;
@@ -35,23 +35,23 @@ use log::{debug, error, info, trace};
 use tokio::net::TcpListener;
 
 use tonic::transport::Server;
-use tonic::{async_trait, Request, Response, Status};
+use tonic::{Request, Response, Status, async_trait};
 
 use datafusion::error::Result as DFResult;
 
-use arrow_flight::{flight_service_server::FlightServiceServer, Ticket};
+use arrow_flight::{Ticket, flight_service_server::FlightServiceServer};
 
 use pyo3::prelude::*;
 
 use parking_lot::{Mutex, RwLock};
 
-use tokio::sync::mpsc::{channel, Receiver, Sender};
+use tokio::sync::mpsc::{Receiver, Sender, channel};
 
 use crate::flight::{FlightHandler, FlightServ};
 use crate::isolator::PartitionGroup;
 use crate::util::{
-    bytes_to_physical_plan, display_plan_with_partition_counts, 
extract_ticket, input_stage_ids,
-    make_client, ResultExt,
+    ResultExt, bytes_to_physical_plan, display_plan_with_partition_counts, 
extract_ticket,
+    input_stage_ids, make_client, register_object_store_for_paths_in_plan,
 };
 
 /// a map of stage_id, partition to a list FlightClients that can serve
@@ -102,7 +102,7 @@ impl DFRayProcessorHandlerInner {
         plan: Arc<dyn ExecutionPlan>,
         partition_group: Vec<usize>,
     ) -> DFResult<Self> {
-        let ctx = Self::configure_ctx(stage_id, stage_addrs, &plan, 
partition_group).await?;
+        let ctx = Self::configure_ctx(stage_id, stage_addrs, plan.clone(), 
partition_group).await?;
 
         Ok(Self { plan, ctx })
     }
@@ -110,10 +110,10 @@ impl DFRayProcessorHandlerInner {
     async fn configure_ctx(
         stage_id: usize,
         stage_addrs: HashMap<usize, HashMap<usize, Vec<String>>>,
-        plan: &Arc<dyn ExecutionPlan>,
+        plan: Arc<dyn ExecutionPlan>,
         partition_group: Vec<usize>,
     ) -> DFResult<SessionContext> {
-        let stage_ids_i_need = input_stage_ids(plan)?;
+        let stage_ids_i_need = input_stage_ids(&plan)?;
 
         // map of stage_id, partition -> Vec<FlightClient>
         let mut client_map = HashMap::new();
@@ -163,6 +163,8 @@ impl DFRayProcessorHandlerInner {
             .build();
         let ctx = SessionContext::new_with_state(state);
 
+        register_object_store_for_paths_in_plan(&ctx, plan.clone())?;
+
         trace!("ctx configured for stage {}", stage_id);
 
         Ok(ctx)
@@ -212,9 +214,7 @@ impl FlightHandler for DFRayProcessorHandler {
 
         trace!(
             "{}, request for partition {} from {}",
-            self.name,
-            partition,
-            remote_addr
+            self.name, partition, remote_addr
         );
 
         let name = self.name.clone();
diff --git a/src/util.rs b/src/util.rs
index 0c07c5c..1fa36e8 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -22,9 +22,12 @@ use async_stream::stream;
 use datafusion::common::internal_datafusion_err;
 use datafusion::common::tree_node::{Transformed, TreeNode};
 use datafusion::datasource::file_format::parquet::ParquetFormat;
-use datafusion::datasource::listing::ListingOptions;
-use datafusion::datasource::physical_plan::ParquetExec;
+use datafusion::datasource::listing::{ListingOptions, ListingTableUrl};
+use datafusion::datasource::physical_plan::{
+    ArrowExec, AvroExec, CsvExec, NdJsonExec, ParquetExec,
+};
 use datafusion::error::DataFusionError;
+use datafusion::execution::object_store::ObjectStoreUrl;
 use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, 
SessionStateBuilder};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{displayable, ExecutionPlan, 
ExecutionPlanProperties};
@@ -32,10 +35,16 @@ use datafusion::prelude::{SessionConfig, SessionContext};
 use datafusion_proto::physical_plan::AsExecutionPlan;
 use datafusion_python::utils::wait_for_future;
 use futures::{Stream, StreamExt};
+use log::debug;
+use object_store::aws::AmazonS3Builder;
+use object_store::gcp::GoogleCloudStorageBuilder;
+use object_store::http::HttpBuilder;
+use object_store::ObjectStore;
 use parking_lot::Mutex;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyList};
 use tonic::transport::Channel;
+use url::Url;
 
 use crate::codec::RayCodec;
 use crate::processor_service::ServiceClients;
@@ -410,6 +419,12 @@ async fn exec_sql(
     for (name, path) in tables {
         let opt =
             
ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(".parquet");
+        debug!("exec_sql: registering table {} at {}", name, path);
+
+        let url = ListingTableUrl::parse(&path)?;
+
+        maybe_register_object_store(&ctx, url.as_ref())?;
+
         ctx.register_listing_table(&name, &path, opt, None, None)
             .await?;
     }
@@ -432,16 +447,21 @@ async fn exec_sql(
 ///   the `url` identifies the parquet files for each listing table and see
 ///   [`datafusion::datasource::listing::ListingTableUrl::parse`] for details
 ///   of supported URL formats
+///  * `listing`: boolean indicating whether this is a listing table path or 
not
 #[pyfunction]
+#[pyo3(signature = (query, tables, listing=false))]
 pub fn exec_sql_on_tables(
     py: Python,
     query: String,
     tables: Bound<'_, PyList>,
+    listing: bool,
 ) -> PyResult<PyObject> {
     let table_vec = {
         let mut v = Vec::with_capacity(tables.len());
         for entry in tables.iter() {
-            v.push(entry.extract::<(String, String)>()?);
+            let (name, path) = entry.extract::<(String, String)>()?;
+            let path = if listing { format!("{path}/") } else { path };
+            v.push((name, path));
         }
         v
     };
@@ -449,6 +469,102 @@ pub fn exec_sql_on_tables(
     batch.to_pyarrow(py)
 }
 
+/// Walks every node of `plan` and, for each file scan found (Parquet, CSV,
+/// NDJSON, Avro or Arrow), registers the object store named by the scan's
+/// `object_store_url` with `ctx` using `maybe_register_object_store`.
+///
+/// # Errors
+///
+/// Returns the first error produced while building or registering a store.
+pub(crate) fn register_object_store_for_paths_in_plan(
+    ctx: &SessionContext,
+    plan: Arc<dyn ExecutionPlan>,
+) -> Result<(), DataFusionError> {
+    let check_plan = |node: Arc<dyn ExecutionPlan>| -> Result<_, DataFusionError> {
+        // `transform_down` invokes this closure for every node, including the
+        // root, so inspect the node itself. Inspecting only its children would
+        // skip a scan that is the root of the plan.
+        let any = node.as_any();
+        let url = if let Some(exec) = any.downcast_ref::<ParquetExec>() {
+            Some(&exec.base_config().object_store_url)
+        } else if let Some(exec) = any.downcast_ref::<CsvExec>() {
+            Some(&exec.base_config().object_store_url)
+        } else if let Some(exec) = any.downcast_ref::<NdJsonExec>() {
+            Some(&exec.base_config().object_store_url)
+        } else if let Some(exec) = any.downcast_ref::<AvroExec>() {
+            Some(&exec.base_config().object_store_url)
+        } else if let Some(exec) = any.downcast_ref::<ArrowExec>() {
+            Some(&exec.base_config().object_store_url)
+        } else {
+            None
+        };
+
+        if let Some(url) = url {
+            maybe_register_object_store(ctx, url.as_ref())?;
+        }
+
+        // nothing is rewritten; the traversal only registers stores
+        Ok(Transformed::no(node))
+    };
+
+    plan.transform_down(check_plan)?;
+
+    Ok(())
+}
+
+/// Registers an object store with `ctx` for the scheme and authority
+/// (bucket, or host and port) of `url`.
+///
+/// * `s3://<bucket>` - Amazon S3, built with `AmazonS3Builder::from_env`.
+/// * `gs://<bucket>` / `gcs://<bucket>` - Google Cloud Storage, built with
+///   `GoogleCloudStorageBuilder::from_env`, registered under the scheme used.
+/// * `http://` / `https://` - an HTTP store rooted at the URL's host and port.
+/// * any other URL - the local filesystem, registered under `file://`.
+///
+/// # Arguments
+///
+/// * `ctx` - A reference to the `SessionContext` where the object store will be registered.
+/// * `url` - The URL of the file or directory that will be read through the store.
+///
+/// # Errors
+///
+/// Returns an error if the URL has no bucket/host name, or if the object
+/// store cannot be built from the environment.
+pub(crate) fn maybe_register_object_store(
+    ctx: &SessionContext,
+    url: &Url,
+) -> Result<(), DataFusionError> {
+    let scheme = url.scheme();
+
+    let (ob_url, object_store) = match scheme {
+        "s3" => {
+            let bucket = url
+                .host_str()
+                .ok_or_else(|| internal_datafusion_err!("missing bucket name in s3:// url"))?;
+
+            let s3 = AmazonS3Builder::from_env()
+                .with_bucket_name(bucket)
+                .build()?;
+            (
+                ObjectStoreUrl::parse(format!("s3://{bucket}"))?,
+                Arc::new(s3) as Arc<dyn ObjectStore>,
+            )
+        }
+        "gs" | "gcs" => {
+            let bucket = url
+                .host_str()
+                .ok_or_else(|| internal_datafusion_err!("missing bucket name in gs:// url"))?;
+
+            // `from_env` (not `new`) so GCS configuration and credentials are
+            // read from the environment, like the S3 branch above
+            let gs = GoogleCloudStorageBuilder::from_env()
+                .with_bucket_name(bucket)
+                .build()?;
+
+            (
+                // register under the scheme actually used: a `gcs://` table
+                // would never find a store registered as `gs://`
+                ObjectStoreUrl::parse(format!("{scheme}://{bucket}"))?,
+                Arc::new(gs) as Arc<dyn ObjectStore>,
+            )
+        }
+        "http" | "https" => {
+            let host = url.host_str().ok_or_else(|| {
+                internal_datafusion_err!("missing host name in {}:// url", scheme)
+            })?;
+
+            // keep an explicit port: stores are looked up by scheme, host and
+            // port, so dropping it would make `http://host:8080/...` unresolvable
+            let base = match url.port() {
+                Some(port) => format!("{scheme}://{host}:{port}"),
+                None => format!("{scheme}://{host}"),
+            };
+
+            let http = HttpBuilder::new().with_url(base.as_str()).build()?;
+
+            (
+                ObjectStoreUrl::parse(&base)?,
+                Arc::new(http) as Arc<dyn ObjectStore>,
+            )
+        }
+        _ => {
+            let local = object_store::local::LocalFileSystem::new();
+            (
+                ObjectStoreUrl::parse("file://")?,
+                Arc::new(local) as Arc<dyn ObjectStore>,
+            )
+        }
+    };
+
+    debug!("Registering object store for {}", ob_url);
+
+    ctx.register_object_store(ob_url.as_ref(), object_store);
+    Ok(())
+}
+
 #[cfg(test)]
 mod test {
     use std::{sync::Arc, vec};
diff --git a/tpch/tpcbench.py b/tpch/tpcbench.py
index 13960bf..dd6df1e 100644
--- a/tpch/tpcbench.py
+++ b/tpch/tpcbench.py
@@ -72,7 +72,7 @@ def main(
         path = os.path.join(data_path, f"{table}.parquet")
         print(f"Registering table {table} using path {path}")
         if listing_tables:
-            ctx.register_listing_table(table, f"{path}/")
+            ctx.register_listing_table(table, path)
         else:
             ctx.register_parquet(table, path)
 
@@ -93,7 +93,6 @@ def main(
         "queries": {},
     }
     if validate:
-        results["local_queries"] = {}
         results["validated"] = {}
 
     queries = range(1, 23) if qnum == -1 else [qnum]
@@ -114,15 +113,13 @@ def main(
         calculated = prettify(batches)
         print(calculated)
         if validate:
-            start_time = time.time()
             tables = [
                 (name, os.path.join(data_path, f"{name}.parquet"))
                 for name in table_names
             ]
-            answer_batches = [b for b in [exec_sql_on_tables(sql, tables)] if 
b]
-            end_time = time.time()
-            results["local_queries"][qnum] = end_time - start_time
-
+            answer_batches = [
+                b for b in [exec_sql_on_tables(sql, tables, listing_tables)] 
if b
+            ]
             expected = prettify(answer_batches)
 
             results["validated"][qnum] = calculated == expected
@@ -137,7 +134,7 @@ def main(
         print(results_dump)
 
     # give ray a moment to clean up
-    print("sleeping for 3 seconds for ray to clean up")
+    print("benchmark complete. sleeping for 3 seconds for ray to clean up")
     time.sleep(3)
 
     if validate and False in results["validated"].values():


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to