This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 479208b0 refactor: port get_scan_files to Ballista (#877)
479208b0 is described below

commit 479208b0fb1b088538691f5364907f6b48ccd625
Author: Andrew Lamb <[email protected]>
AuthorDate: Fri Sep 15 12:09:26 2023 -0400

    refactor: port get_scan_files to Ballista (#877)
---
 ballista/scheduler/src/cluster/kv.rs     |  3 +--
 ballista/scheduler/src/cluster/memory.rs |  3 +--
 ballista/scheduler/src/cluster/mod.rs    | 30 ++++++++++++++++++++++++++++++
 3 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/ballista/scheduler/src/cluster/kv.rs b/ballista/scheduler/src/cluster/kv.rs
index 53372f8d..25bb8e75 100644
--- a/ballista/scheduler/src/cluster/kv.rs
+++ b/ballista/scheduler/src/cluster/kv.rs
@@ -17,7 +17,7 @@
 
 use crate::cluster::storage::{KeyValueStore, Keyspace, Lock, Operation, WatchEvent};
 use crate::cluster::{
-    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin,
+    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin, get_scan_files,
     is_skip_consistent_hash, BoundTask, ClusterState, ExecutorHeartbeatStream,
     ExecutorSlot, JobState, JobStateEvent, JobStateEventStream, JobStatus,
     TaskDistributionPolicy, TopologyNode,
@@ -39,7 +39,6 @@ use ballista_core::serde::protobuf::{
 use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
 use ballista_core::serde::BallistaCodec;
 use dashmap::DashMap;
-use datafusion::datasource::physical_plan::get_scan_files;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::SessionContext;
 use datafusion_proto::logical_plan::AsLogicalPlan;
diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs
index 03a9358d..f2fe589a 100644
--- a/ballista/scheduler/src/cluster/memory.rs
+++ b/ballista/scheduler/src/cluster/memory.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::cluster::{
-    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin,
+    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin, get_scan_files,
     is_skip_consistent_hash, BoundTask, ClusterState, ExecutorSlot, JobState,
 JobStateEvent, JobStateEventStream, JobStatus, TaskDistributionPolicy, TopologyNode,
 };
@@ -42,7 +42,6 @@ use std::collections::{HashMap, HashSet};
 use std::ops::DerefMut;
 
 use ballista_core::consistent_hash::node::Node;
-use datafusion::datasource::physical_plan::get_scan_files;
 use datafusion::physical_plan::ExecutionPlan;
 use std::sync::Arc;
 use tokio::sync::{Mutex, MutexGuard};
diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs
index 12938aa1..793d3fc1 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -21,7 +21,11 @@ use std::pin::Pin;
 use std::sync::Arc;
 
 use clap::ArgEnum;
+use datafusion::common::tree_node::TreeNode;
+use datafusion::common::tree_node::VisitRecursion;
 use datafusion::datasource::listing::PartitionedFile;
+use datafusion::datasource::physical_plan::{AvroExec, CsvExec, NdJsonExec, ParquetExec};
+use datafusion::error::DataFusionError;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::SessionContext;
 use datafusion_proto::logical_plan::AsLogicalPlan;
@@ -680,6 +684,32 @@ pub(crate) fn is_skip_consistent_hash(scan_files: &[Vec<Vec<PartitionedFile>>])
     scan_files.is_empty() || scan_files.len() > 1
 }
 
+/// Get all of the [`PartitionedFile`] to be scanned for an [`ExecutionPlan`]
+pub(crate) fn get_scan_files(
+    plan: Arc<dyn ExecutionPlan>,
+) -> std::result::Result<Vec<Vec<Vec<PartitionedFile>>>, DataFusionError> {
+    let mut collector: Vec<Vec<Vec<PartitionedFile>>> = vec![];
+    plan.apply(&mut |plan| {
+        let plan_any = plan.as_any();
+        let file_groups =
+            if let Some(parquet_exec) = plan_any.downcast_ref::<ParquetExec>() {
+                parquet_exec.base_config().file_groups.clone()
+            } else if let Some(avro_exec) = plan_any.downcast_ref::<AvroExec>() {
+                avro_exec.base_config().file_groups.clone()
+            } else if let Some(json_exec) = plan_any.downcast_ref::<NdJsonExec>() {
+                json_exec.base_config().file_groups.clone()
+            } else if let Some(csv_exec) = plan_any.downcast_ref::<CsvExec>() {
+                csv_exec.base_config().file_groups.clone()
+            } else {
+                return Ok(VisitRecursion::Continue);
+            };
+
+        collector.push(file_groups);
+        Ok(VisitRecursion::Skip)
+    })?;
+    Ok(collector)
+}
+
 #[derive(Clone)]
 pub struct TopologyNode {
     pub id: String,

Reply via email to