This is an automated email from the ASF dual-hosted git repository.

nju_yaho pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new ba4d9d3b Implement 3-phase consistent hash based task assignment 
policy (#833)
ba4d9d3b is described below

commit ba4d9d3b065cf3bf91fd948408dac514cc0c905a
Author: yahoNanJing <[email protected]>
AuthorDate: Thu Jul 20 10:19:11 2023 +0800

    Implement 3-phase consistent hash based task assignment policy (#833)
    
    * Implement 3-phase consistent hash based task assignment policy
    
    * Add unit test for bind_task_consistent_hash
    
    * Add an option ballista.data_cache.enabled to indicate whether to enable the
    data cache on the executor side
    
    * Fix compilation errors and cargo clippy warnings
    
    ---------
    
    Co-authored-by: yangzhong <[email protected]>
---
 ballista/core/src/config.rs                     |   2 +
 ballista/executor/src/execution_loop.rs         |   2 +-
 ballista/executor/src/executor.rs               |  22 +-
 ballista/executor/src/executor_process.rs       |  30 +-
 ballista/executor/src/executor_server.rs        |  14 +-
 ballista/executor/src/standalone.rs             |   1 +
 ballista/scheduler/scheduler_config_spec.toml   |  14 +-
 ballista/scheduler/src/bin/main.rs              |  19 +-
 ballista/scheduler/src/cluster/kv.rs            | 106 ++++++-
 ballista/scheduler/src/cluster/memory.rs        | 105 ++++++-
 ballista/scheduler/src/cluster/mod.rs           | 402 +++++++++++++++++++++---
 ballista/scheduler/src/config.rs                |  31 +-
 ballista/scheduler/src/scheduler_server/grpc.rs |  10 +-
 ballista/scheduler/src/state/execution_graph.rs |   5 +-
 ballista/scheduler/src/state/task_manager.rs    | 112 +++++--
 15 files changed, 763 insertions(+), 112 deletions(-)

diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index dc0cc7fe..058cbe4d 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -37,6 +37,8 @@ pub const BALLISTA_REPARTITION_AGGREGATIONS: &str = 
"ballista.repartition.aggreg
 pub const BALLISTA_REPARTITION_WINDOWS: &str = "ballista.repartition.windows";
 pub const BALLISTA_PARQUET_PRUNING: &str = "ballista.parquet.pruning";
 pub const BALLISTA_COLLECT_STATISTICS: &str = "ballista.collect_statistics";
+/// Indicates whether to enable the data cache for a task
+pub const BALLISTA_DATA_CACHE_ENABLED: &str = "ballista.data_cache.enabled";
 
 pub const BALLISTA_WITH_INFORMATION_SCHEMA: &str = 
"ballista.with_information_schema";
 /// give a plugin files dir, and then the dynamic library files in this dir 
will be load when scheduler state init.
diff --git a/ballista/executor/src/execution_loop.rs 
b/ballista/executor/src/execution_loop.rs
index 6d7b4aec..6b897b56 100644
--- a/ballista/executor/src/execution_loop.rs
+++ b/ballista/executor/src/execution_loop.rs
@@ -192,7 +192,7 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 
'static + AsExecutionP
     for window_func in executor.window_functions.clone() {
         task_window_functions.insert(window_func.0, window_func.1);
     }
-    let runtime = executor.runtime.clone();
+    let runtime = executor.get_runtime(false);
     let session_id = task.session_id.clone();
     let task_context = Arc::new(TaskContext::new(
         Some(task_identity.clone()),
diff --git a/ballista/executor/src/executor.rs 
b/ballista/executor/src/executor.rs
index f5828bbb..4ee57eb8 100644
--- a/ballista/executor/src/executor.rs
+++ b/ballista/executor/src/executor.rs
@@ -73,7 +73,12 @@ pub struct Executor {
     pub window_functions: HashMap<String, Arc<WindowUDF>>,
 
     /// Runtime environment for Executor
-    pub runtime: Arc<RuntimeEnv>,
+    runtime: Arc<RuntimeEnv>,
+
+    /// Runtime environment for Executor with data cache.
+    /// The difference from [`runtime`] is that it uses a different [`object_store_registry`];
+    /// everything else is shared with [`runtime`].
+    runtime_with_data_cache: Option<Arc<RuntimeEnv>>,
 
     /// Collector for runtime execution metrics
     pub metrics_collector: Arc<dyn ExecutorMetricsCollector>,
@@ -95,6 +100,7 @@ impl Executor {
         metadata: ExecutorRegistration,
         work_dir: &str,
         runtime: Arc<RuntimeEnv>,
+        runtime_with_data_cache: Option<Arc<RuntimeEnv>>,
         metrics_collector: Arc<dyn ExecutorMetricsCollector>,
         concurrent_tasks: usize,
         execution_engine: Option<Arc<dyn ExecutionEngine>>,
@@ -107,6 +113,7 @@ impl Executor {
             aggregate_functions: HashMap::new(),
             window_functions: HashMap::new(),
             runtime,
+            runtime_with_data_cache,
             metrics_collector,
             concurrent_tasks,
             abort_handles: Default::default(),
@@ -117,6 +124,18 @@ impl Executor {
 }
 
 impl Executor {
+    pub fn get_runtime(&self, data_cache: bool) -> Arc<RuntimeEnv> {
+        if data_cache {
+            if let Some(runtime) = self.runtime_with_data_cache.clone() {
+                runtime
+            } else {
+                self.runtime.clone()
+            }
+        } else {
+            self.runtime.clone()
+        }
+    }
+
     /// Execute one partition of a query stage and persist the result to disk 
in IPC format. On
     /// success, return a RecordBatch containing metadata about the results, 
including path
     /// and statistics.
@@ -319,6 +338,7 @@ mod test {
             executor_registration,
             &work_dir,
             ctx.runtime_env(),
+            None,
             Arc::new(LoggingMetricsCollector {}),
             2,
             None,
diff --git a/ballista/executor/src/executor_process.rs 
b/ballista/executor/src/executor_process.rs
index 5fbb38ca..9d227981 100644
--- a/ballista/executor/src/executor_process.rs
+++ b/ballista/executor/src/executor_process.rs
@@ -47,7 +47,7 @@ use ballista_core::cache_layer::{
 use ballista_core::config::{DataCachePolicy, LogRotationPolicy, 
TaskSchedulingPolicy};
 use ballista_core::error::BallistaError;
 #[cfg(not(windows))]
-use ballista_core::object_store_registry::cache::with_cache_layer;
+use 
ballista_core::object_store_registry::cache::CachedBasedObjectStoreRegistry;
 use ballista_core::object_store_registry::with_object_store_registry;
 use ballista_core::serde::protobuf::executor_resource::Resource;
 use ballista_core::serde::protobuf::executor_status::Status;
@@ -186,9 +186,16 @@ pub async fn start_executor_process(opt: 
Arc<ExecutorProcessConfig>) -> Result<(
     };
 
     let config = RuntimeConfig::new().with_temp_file_path(work_dir.clone());
+    let runtime = {
+        let config = with_object_store_registry(config.clone());
+        Arc::new(RuntimeEnv::new(config).map_err(|_| {
+            BallistaError::Internal("Failed to init Executor 
RuntimeEnv".to_owned())
+        })?)
+    };
+
     // Set the object store registry
     #[cfg(not(windows))]
-    let config = {
+    let runtime_with_data_cache = {
         let cache_dir = opt.cache_dir.clone();
         let cache_capacity = opt.cache_capacity;
         let cache_io_concurrency = opt.cache_io_concurrency;
@@ -206,17 +213,21 @@ pub async fn start_executor_process(opt: 
Arc<ExecutorProcessConfig>) -> Result<(
                     }
                 });
         if let Some(cache_layer) = cache_layer {
-            with_cache_layer(config, cache_layer)
+            let registry = Arc::new(CachedBasedObjectStoreRegistry::new(
+                runtime.object_store_registry.clone(),
+                cache_layer,
+            ));
+            Some(Arc::new(RuntimeEnv {
+                memory_pool: runtime.memory_pool.clone(),
+                disk_manager: runtime.disk_manager.clone(),
+                object_store_registry: registry,
+            }))
         } else {
-            with_object_store_registry(config)
+            None
         }
     };
     #[cfg(windows)]
-    let config = with_object_store_registry(config);
-
-    let runtime = Arc::new(RuntimeEnv::new(config).map_err(|_| {
-        BallistaError::Internal("Failed to init Executor 
RuntimeEnv".to_owned())
-    })?);
+    let runtime_with_data_cache = { None };
 
     let metrics_collector = Arc::new(LoggingMetricsCollector::default());
 
@@ -224,6 +235,7 @@ pub async fn start_executor_process(opt: 
Arc<ExecutorProcessConfig>) -> Result<(
         executor_meta,
         &work_dir,
         runtime,
+        runtime_with_data_cache,
         metrics_collector,
         concurrent_tasks,
         opt.execution_engine.clone(),
diff --git a/ballista/executor/src/executor_server.rs 
b/ballista/executor/src/executor_server.rs
index 6341a38d..57223c54 100644
--- a/ballista/executor/src/executor_server.rs
+++ b/ballista/executor/src/executor_server.rs
@@ -28,6 +28,7 @@ use log::{debug, error, info, warn};
 use tonic::transport::Channel;
 use tonic::{Request, Response, Status};
 
+use ballista_core::config::BALLISTA_DATA_CACHE_ENABLED;
 use ballista_core::error::BallistaError;
 use ballista_core::serde::protobuf::{
     executor_grpc_server::{ExecutorGrpc, ExecutorGrpcServer},
@@ -334,6 +335,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorServer<T,
 
         let task_context = {
             let task_props = task.props;
+            let data_cache = task_props
+                .get(BALLISTA_DATA_CACHE_ENABLED)
+                .map(|data_cache| data_cache.parse().unwrap_or(false))
+                .unwrap_or(false);
             let mut config = ConfigOptions::new();
             for (k, v) in task_props.iter() {
                 if let Err(e) = config.set(k, v) {
@@ -343,7 +348,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorServer<T,
             let session_config = SessionConfig::from(config);
 
             let function_registry = task.function_registry;
-            let runtime = self.executor.runtime.clone();
+            if data_cache {
+                info!("Data cache will be enabled for {}", task_identity);
+            }
+            let runtime = self.executor.get_runtime(data_cache);
 
             Arc::new(TaskContext::new(
                 Some(task_identity.clone()),
@@ -632,7 +640,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorGrpc
                     scheduler_id: scheduler_id.clone(),
                     task: get_task_definition(
                         task,
-                        self.executor.runtime.clone(),
+                        self.executor.get_runtime(false),
                         self.executor.scalar_functions.clone(),
                         self.executor.aggregate_functions.clone(),
                         self.executor.window_functions.clone(),
@@ -660,7 +668,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorGrpc
         for multi_task in multi_tasks {
             let multi_task: Vec<TaskDefinition> = get_task_definition_vec(
                 multi_task,
-                self.executor.runtime.clone(),
+                self.executor.get_runtime(false),
                 self.executor.scalar_functions.clone(),
                 self.executor.aggregate_functions.clone(),
                 self.executor.window_functions.clone(),
diff --git a/ballista/executor/src/standalone.rs 
b/ballista/executor/src/standalone.rs
index 38e27713..971af517 100644
--- a/ballista/executor/src/standalone.rs
+++ b/ballista/executor/src/standalone.rs
@@ -82,6 +82,7 @@ pub async fn new_standalone_executor<
         executor_meta,
         &work_dir,
         Arc::new(RuntimeEnv::new(config).unwrap()),
+        None,
         Arc::new(LoggingMetricsCollector::default()),
         concurrent_tasks,
         None,
diff --git a/ballista/scheduler/scheduler_config_spec.toml 
b/ballista/scheduler/scheduler_config_spec.toml
index 1aac653a..c9f5154a 100644
--- a/ballista/scheduler/scheduler_config_spec.toml
+++ b/ballista/scheduler/scheduler_config_spec.toml
@@ -97,9 +97,21 @@ doc = "Delayed interval for cleaning up finished job state. 
Default: 3600"
 [[param]]
 name = "task_distribution"
 type = "ballista_scheduler::config::TaskDistribution"
-doc = "The policy of distributing tasks to available executor slots, possible 
values: bias, round-robin. Default: bias"
+doc = "The policy of distributing tasks to available executor slots, possible 
values: bias, round-robin, consistent-hash. Default: bias"
 default = "ballista_scheduler::config::TaskDistribution::Bias"
 
+[[param]]
+name = "consistent_hash_num_replicas"
+type = "u32"
+default = "31"
+doc = "Number of replicas of each node used by consistent hashing. Default: 31"
+
+[[param]]
+name = "consistent_hash_tolerance"
+type = "u32"
+default = "0"
+doc = "Tolerance of the consistent hashing policy for task scheduling. 
Default: 0"
+
 [[param]]
 name = "plugin_dir"
 type = "String"
diff --git a/ballista/scheduler/src/bin/main.rs 
b/ballista/scheduler/src/bin/main.rs
index ec59b409..46ac7cc9 100644
--- a/ballista/scheduler/src/bin/main.rs
+++ b/ballista/scheduler/src/bin/main.rs
@@ -27,7 +27,9 @@ use ballista_core::config::LogRotationPolicy;
 use ballista_core::print_version;
 use ballista_scheduler::cluster::BallistaCluster;
 use ballista_scheduler::cluster::ClusterStorage;
-use ballista_scheduler::config::{ClusterStorageConfig, SchedulerConfig};
+use ballista_scheduler::config::{
+    ClusterStorageConfig, SchedulerConfig, TaskDistribution, 
TaskDistributionPolicy,
+};
 use ballista_scheduler::scheduler_process::start_server;
 use tracing_subscriber::EnvFilter;
 
@@ -121,13 +123,26 @@ async fn main() -> Result<()> {
         }
     };
 
+    let task_distribution = match opt.task_distribution {
+        TaskDistribution::Bias => TaskDistributionPolicy::Bias,
+        TaskDistribution::RoundRobin => TaskDistributionPolicy::RoundRobin,
+        TaskDistribution::ConsistentHash => {
+            let num_replicas = opt.consistent_hash_num_replicas as usize;
+            let tolerance = opt.consistent_hash_tolerance as usize;
+            TaskDistributionPolicy::ConsistentHash {
+                num_replicas,
+                tolerance,
+            }
+        }
+    };
+
     let config = SchedulerConfig {
         namespace: opt.namespace,
         external_host: opt.external_host,
         bind_port: opt.bind_port,
         scheduling_policy: opt.scheduler_policy,
         event_loop_buffer_size: opt.event_loop_buffer_size,
-        task_distribution: opt.task_distribution,
+        task_distribution,
         finished_job_data_clean_up_interval_seconds: opt
             .finished_job_data_clean_up_interval_seconds,
         finished_job_state_clean_up_interval_seconds: opt
diff --git a/ballista/scheduler/src/cluster/kv.rs 
b/ballista/scheduler/src/cluster/kv.rs
index a8852fb6..53372f8d 100644
--- a/ballista/scheduler/src/cluster/kv.rs
+++ b/ballista/scheduler/src/cluster/kv.rs
@@ -17,9 +17,10 @@
 
 use crate::cluster::storage::{KeyValueStore, Keyspace, Lock, Operation, 
WatchEvent};
 use crate::cluster::{
-    bind_task_bias, bind_task_round_robin, BoundTask, ClusterState,
-    ExecutorHeartbeatStream, ExecutorSlot, JobState, JobStateEvent, 
JobStateEventStream,
-    JobStatus, TaskDistribution,
+    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin,
+    is_skip_consistent_hash, BoundTask, ClusterState, ExecutorHeartbeatStream,
+    ExecutorSlot, JobState, JobStateEvent, JobStateEventStream, JobStatus,
+    TaskDistributionPolicy, TopologyNode,
 };
 use crate::scheduler_server::{timestamp_secs, SessionBuilder};
 use crate::state::execution_graph::ExecutionGraph;
@@ -28,6 +29,7 @@ use crate::state::task_manager::JobInfoCache;
 use crate::state::{decode_into, decode_protobuf};
 use async_trait::async_trait;
 use ballista_core::config::BallistaConfig;
+use ballista_core::consistent_hash::node::Node;
 use ballista_core::error::{BallistaError, Result};
 use ballista_core::serde::protobuf::job_status::Status;
 use ballista_core::serde::protobuf::{
@@ -37,13 +39,15 @@ use ballista_core::serde::protobuf::{
 use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
 use ballista_core::serde::BallistaCodec;
 use dashmap::DashMap;
+use datafusion::datasource::physical_plan::get_scan_files;
+use datafusion::physical_plan::ExecutionPlan;
 use datafusion::prelude::SessionContext;
 use datafusion_proto::logical_plan::AsLogicalPlan;
 use datafusion_proto::physical_plan::AsExecutionPlan;
 use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
 use futures::StreamExt;
 use itertools::Itertools;
-use log::{info, warn};
+use log::{error, info, warn};
 use prost::Message;
 use std::collections::{HashMap, HashSet};
 use std::future::Future;
@@ -135,6 +139,42 @@ impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 
'static + AsExecutionPlan>
             })
             .boxed())
     }
+
+    /// Get the topology nodes of the cluster for consistent hashing
+    fn get_topology_nodes(
+        &self,
+        available_slots: &[AvailableTaskSlots],
+        executors: Option<HashSet<String>>,
+    ) -> HashMap<String, TopologyNode> {
+        let mut nodes: HashMap<String, TopologyNode> = HashMap::new();
+        for slots in available_slots {
+            if let Some(executors) = executors.as_ref() {
+                if !executors.contains(&slots.executor_id) {
+                    continue;
+                }
+            }
+            if let Some(executor) = self.executors.get(&slots.executor_id) {
+                let node = TopologyNode::new(
+                    &executor.host,
+                    executor.port,
+                    &slots.executor_id,
+                    self.executor_heartbeats
+                        .get(&executor.id)
+                        .map(|heartbeat| heartbeat.timestamp)
+                        .unwrap_or(0),
+                    slots.slots,
+                );
+                if let Some(existing_node) = nodes.get(node.name()) {
+                    if existing_node.last_seen_ts < node.last_seen_ts {
+                        nodes.insert(node.name().to_string(), node);
+                    }
+                } else {
+                    nodes.insert(node.name().to_string(), node);
+                }
+            }
+        }
+        nodes
+    }
 }
 
 #[async_trait]
@@ -177,7 +217,7 @@ impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 
'static + AsExecutionPlan>
 
     async fn bind_schedulable_tasks(
         &self,
-        distribution: TaskDistribution,
+        distribution: TaskDistributionPolicy,
         active_jobs: Arc<HashMap<String, JobInfoCache>>,
         executors: Option<HashSet<String>>,
     ) -> Result<Vec<BoundTask>> {
@@ -207,12 +247,64 @@ impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 
'static + AsExecutionPlan>
                 .collect();
 
             let bound_tasks = match distribution {
-                TaskDistribution::Bias => {
+                TaskDistributionPolicy::Bias => {
                     bind_task_bias(available_slots, active_jobs, |_| 
false).await
                 }
-                TaskDistribution::RoundRobin => {
+                TaskDistributionPolicy::RoundRobin => {
                     bind_task_round_robin(available_slots, active_jobs, |_| 
false).await
                 }
+                TaskDistributionPolicy::ConsistentHash {
+                    num_replicas,
+                    tolerance,
+                } => {
+                    let mut bound_tasks = bind_task_round_robin(
+                        available_slots,
+                        active_jobs.clone(),
+                        |stage_plan: Arc<dyn ExecutionPlan>| {
+                            if let Ok(scan_files) = get_scan_files(stage_plan) 
{
+                                // Skip the stages that the consistent-hash pass below will bind.
+                                !is_skip_consistent_hash(&scan_files)
+                            } else {
+                                false
+                            }
+                        },
+                    )
+                    .await;
+                    info!("{} tasks bound by round robin policy", 
bound_tasks.len());
+                    let (bound_tasks_consistent_hash, ch_topology) =
+                        bind_task_consistent_hash(
+                            self.get_topology_nodes(&slots.task_slots, 
executors),
+                            num_replicas,
+                            tolerance,
+                            active_jobs,
+                            |_, plan| get_scan_files(plan),
+                        )
+                        .await?;
+                    info!(
+                        "{} tasks bound by consistent hashing policy",
+                        bound_tasks_consistent_hash.len()
+                    );
+                    if !bound_tasks_consistent_hash.is_empty() {
+                        bound_tasks.extend(bound_tasks_consistent_hash);
+                        // Update the available slots
+                        let mut executor_data: HashMap<String, 
AvailableTaskSlots> =
+                            slots
+                                .task_slots
+                                .into_iter()
+                                .map(|slots| (slots.executor_id.clone(), 
slots))
+                                .collect();
+                        let ch_topology = ch_topology.unwrap();
+                        for node in ch_topology.nodes() {
+                            if let Some(data) = 
executor_data.get_mut(&node.id) {
+                                data.slots = node.available_slots;
+                            } else {
+                                error!("Fail to find executor data for {}", 
&node.id);
+                            }
+                        }
+                        slots.task_slots = 
executor_data.into_values().collect();
+                    }
+                    bound_tasks
+                }
             };
 
             if !bound_tasks.is_empty() {
diff --git a/ballista/scheduler/src/cluster/memory.rs 
b/ballista/scheduler/src/cluster/memory.rs
index b9dcef40..03a9358d 100644
--- a/ballista/scheduler/src/cluster/memory.rs
+++ b/ballista/scheduler/src/cluster/memory.rs
@@ -16,8 +16,9 @@
 // under the License.
 
 use crate::cluster::{
-    bind_task_bias, bind_task_round_robin, BoundTask, ClusterState, 
ExecutorSlot,
-    JobState, JobStateEvent, JobStateEventStream, JobStatus, TaskDistribution,
+    bind_task_bias, bind_task_consistent_hash, bind_task_round_robin,
+    is_skip_consistent_hash, BoundTask, ClusterState, ExecutorSlot, JobState,
+    JobStateEvent, JobStateEventStream, JobStatus, TaskDistributionPolicy, 
TopologyNode,
 };
 use crate::state::execution_graph::ExecutionGraph;
 use async_trait::async_trait;
@@ -36,12 +37,15 @@ use crate::scheduler_server::{timestamp_millis, 
timestamp_secs, SessionBuilder};
 use crate::state::session_manager::create_datafusion_context;
 use crate::state::task_manager::JobInfoCache;
 use ballista_core::serde::protobuf::job_status::Status;
-use log::warn;
+use log::{error, info, warn};
 use std::collections::{HashMap, HashSet};
 use std::ops::DerefMut;
 
+use ballista_core::consistent_hash::node::Node;
+use datafusion::datasource::physical_plan::get_scan_files;
+use datafusion::physical_plan::ExecutionPlan;
 use std::sync::Arc;
-use tokio::sync::Mutex;
+use tokio::sync::{Mutex, MutexGuard};
 use tracing::debug;
 
 #[derive(Default)]
@@ -54,11 +58,49 @@ pub struct InMemoryClusterState {
     heartbeats: DashMap<String, ExecutorHeartbeat>,
 }
 
+impl InMemoryClusterState {
+    /// Get the topology nodes of the cluster for consistent hashing
+    fn get_topology_nodes(
+        &self,
+        guard: &MutexGuard<HashMap<String, AvailableTaskSlots>>,
+        executors: Option<HashSet<String>>,
+    ) -> HashMap<String, TopologyNode> {
+        let mut nodes: HashMap<String, TopologyNode> = HashMap::new();
+        for (executor_id, slots) in guard.iter() {
+            if let Some(executors) = executors.as_ref() {
+                if !executors.contains(executor_id) {
+                    continue;
+                }
+            }
+            if let Some(executor) = self.executors.get(&slots.executor_id) {
+                let node = TopologyNode::new(
+                    &executor.host,
+                    executor.port,
+                    &slots.executor_id,
+                    self.heartbeats
+                        .get(&executor.id)
+                        .map(|heartbeat| heartbeat.timestamp)
+                        .unwrap_or(0),
+                    slots.slots,
+                );
+                if let Some(existing_node) = nodes.get(node.name()) {
+                    if existing_node.last_seen_ts < node.last_seen_ts {
+                        nodes.insert(node.name().to_string(), node);
+                    }
+                } else {
+                    nodes.insert(node.name().to_string(), node);
+                }
+            }
+        }
+        nodes
+    }
+}
+
 #[async_trait]
 impl ClusterState for InMemoryClusterState {
     async fn bind_schedulable_tasks(
         &self,
-        distribution: TaskDistribution,
+        distribution: TaskDistributionPolicy,
         active_jobs: Arc<HashMap<String, JobInfoCache>>,
         executors: Option<HashSet<String>>,
     ) -> Result<Vec<BoundTask>> {
@@ -76,16 +118,61 @@ impl ClusterState for InMemoryClusterState {
             })
             .collect();
 
-        let schedulable_tasks = match distribution {
-            TaskDistribution::Bias => {
+        let bound_tasks = match distribution {
+            TaskDistributionPolicy::Bias => {
                 bind_task_bias(available_slots, active_jobs, |_| false).await
             }
-            TaskDistribution::RoundRobin => {
+            TaskDistributionPolicy::RoundRobin => {
                 bind_task_round_robin(available_slots, active_jobs, |_| 
false).await
             }
+            TaskDistributionPolicy::ConsistentHash {
+                num_replicas,
+                tolerance,
+            } => {
+                let mut bound_tasks = bind_task_round_robin(
+                    available_slots,
+                    active_jobs.clone(),
+                    |stage_plan: Arc<dyn ExecutionPlan>| {
+                        if let Ok(scan_files) = get_scan_files(stage_plan) {
+                            // Skip the stages that the consistent-hash pass below will bind.
+                            !is_skip_consistent_hash(&scan_files)
+                        } else {
+                            false
+                        }
+                    },
+                )
+                .await;
+                info!("{} tasks bound by round robin policy", 
bound_tasks.len());
+                let (bound_tasks_consistent_hash, ch_topology) =
+                    bind_task_consistent_hash(
+                        self.get_topology_nodes(&guard, executors),
+                        num_replicas,
+                        tolerance,
+                        active_jobs,
+                        |_, plan| get_scan_files(plan),
+                    )
+                    .await?;
+                info!(
+                    "{} tasks bound by consistent hashing policy",
+                    bound_tasks_consistent_hash.len()
+                );
+                if !bound_tasks_consistent_hash.is_empty() {
+                    bound_tasks.extend(bound_tasks_consistent_hash);
+                    // Update the available slots
+                    let ch_topology = ch_topology.unwrap();
+                    for node in ch_topology.nodes() {
+                        if let Some(data) = guard.get_mut(&node.id) {
+                            data.slots = node.available_slots;
+                        } else {
+                            error!("Fail to find executor data for {}", 
&node.id);
+                        }
+                    }
+                }
+                bound_tasks
+            }
         };
 
-        Ok(schedulable_tasks)
+        Ok(bound_tasks)
     }
 
     async fn unbind_tasks(&self, executor_slots: Vec<ExecutorSlot>) -> 
Result<()> {
diff --git a/ballista/scheduler/src/cluster/mod.rs 
b/ballista/scheduler/src/cluster/mod.rs
index c3011184..12938aa1 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -15,25 +15,23 @@
 // specific language governing permissions and limitations
 // under the License.
 
-pub mod event;
-pub mod kv;
-pub mod memory;
-pub mod storage;
+use std::collections::{HashMap, HashSet};
+use std::fmt;
+use std::pin::Pin;
+use std::sync::Arc;
 
-#[cfg(test)]
-#[allow(clippy::uninlined_format_args)]
-pub mod test_util;
+use clap::ArgEnum;
+use datafusion::datasource::listing::PartitionedFile;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::prelude::SessionContext;
+use datafusion_proto::logical_plan::AsLogicalPlan;
+use datafusion_proto::physical_plan::AsExecutionPlan;
+use futures::Stream;
+use log::{debug, info, warn};
 
-use crate::cluster::kv::KeyValueState;
-use crate::cluster::memory::{InMemoryClusterState, InMemoryJobState};
-use crate::cluster::storage::etcd::EtcdClient;
-use crate::cluster::storage::sled::SledClient;
-use crate::cluster::storage::KeyValueStore;
-use crate::config::{ClusterStorageConfig, SchedulerConfig, TaskDistribution};
-use crate::scheduler_server::SessionBuilder;
-use crate::state::execution_graph::{create_task_info, ExecutionGraph, 
TaskDescription};
-use crate::state::task_manager::JobInfoCache;
 use ballista_core::config::BallistaConfig;
+use ballista_core::consistent_hash;
+use ballista_core::consistent_hash::ConsistentHash;
 use ballista_core::error::{BallistaError, Result};
 use ballista_core::serde::protobuf::{
     job_status, AvailableTaskSlots, ExecutorHeartbeat, JobStatus,
@@ -41,17 +39,25 @@ use ballista_core::serde::protobuf::{
 use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata, 
PartitionId};
 use ballista_core::serde::BallistaCodec;
 use ballista_core::utils::default_session_builder;
-use clap::ArgEnum;
-use datafusion::physical_plan::ExecutionPlan;
-use datafusion::prelude::SessionContext;
-use datafusion_proto::logical_plan::AsLogicalPlan;
-use datafusion_proto::physical_plan::AsExecutionPlan;
-use futures::Stream;
-use log::{debug, info, warn};
-use std::collections::{HashMap, HashSet};
-use std::fmt;
-use std::pin::Pin;
-use std::sync::Arc;
+
+use crate::cluster::kv::KeyValueState;
+use crate::cluster::memory::{InMemoryClusterState, InMemoryJobState};
+use crate::cluster::storage::etcd::EtcdClient;
+use crate::cluster::storage::sled::SledClient;
+use crate::cluster::storage::KeyValueStore;
+use crate::config::{ClusterStorageConfig, SchedulerConfig, 
TaskDistributionPolicy};
+use crate::scheduler_server::SessionBuilder;
+use crate::state::execution_graph::{create_task_info, ExecutionGraph, 
TaskDescription};
+use crate::state::task_manager::JobInfoCache;
+
+pub mod event;
+pub mod kv;
+pub mod memory;
+pub mod storage;
+
+#[cfg(test)]
+#[allow(clippy::uninlined_format_args)]
+pub mod test_util;
 
 // an enum used to configure the backend
 // needs to be visible to code generated by configure_me
@@ -218,7 +224,7 @@ pub trait ClusterState: Send + Sync + 'static {
     /// If `executors` is provided, only bind slots from the specified 
executor IDs
     async fn bind_schedulable_tasks(
         &self,
-        distribution: TaskDistribution,
+        distribution: TaskDistributionPolicy,
         active_jobs: Arc<HashMap<String, JobInfoCache>>,
         executors: Option<HashSet<String>>,
     ) -> Result<Vec<BoundTask>>;
@@ -442,6 +448,7 @@ pub(crate) async fn bind_task_bias(
                     stage_attempt_num: running_stage.stage_attempt_num,
                     task_id,
                     task_attempt: 
running_stage.task_failure_numbers[partition_id],
+                    data_cache: false,
                     plan: running_stage.plan.clone(),
                 };
                 schedulable_tasks.push((executor_id, task_desc));
@@ -530,6 +537,7 @@ pub(crate) async fn bind_task_round_robin(
                     stage_attempt_num: running_stage.stage_attempt_num,
                     task_id,
                     task_attempt: 
running_stage.task_failure_numbers[partition_id],
+                    data_cache: false,
                     plan: running_stage.plan.clone(),
                 };
                 schedulable_tasks.push((executor_id, task_desc));
@@ -547,21 +555,191 @@ pub(crate) async fn bind_task_round_robin(
     schedulable_tasks
 }
 
+type GetScanFilesFunc = fn(
+    &str,
+    Arc<dyn ExecutionPlan>,
+) -> datafusion::common::Result<Vec<Vec<Vec<PartitionedFile>>>>;
+
+pub(crate) async fn bind_task_consistent_hash(
+    topology_nodes: HashMap<String, TopologyNode>,
+    num_replicas: usize,
+    tolerance: usize,
+    active_jobs: Arc<HashMap<String, JobInfoCache>>,
+    get_scan_files: GetScanFilesFunc,
+) -> Result<(Vec<BoundTask>, Option<ConsistentHash<TopologyNode>>)> {
+    let mut total_slots = 0usize;
+    for (_, node) in topology_nodes.iter() {
+        total_slots += node.available_slots as usize;
+    }
+    if total_slots == 0 {
+        info!("Not enough available executor slots for binding tasks with 
consistent hashing policy!!!");
+        return Ok((vec![], None));
+    }
+    info!("Total slot number is {}", total_slots);
+
+    let node_replicas = topology_nodes
+        .into_values()
+        .map(|node| (node, num_replicas))
+        .collect::<Vec<_>>();
+    let mut ch_topology: ConsistentHash<TopologyNode> =
+        ConsistentHash::new(node_replicas);
+
+    let mut schedulable_tasks: Vec<BoundTask> = vec![];
+    for (job_id, job_info) in active_jobs.iter() {
+        if !matches!(job_info.status, Some(job_status::Status::Running(_))) {
+            debug!(
+                "Job {} is not in running status and will be skipped",
+                job_id
+            );
+            continue;
+        }
+        let mut graph = job_info.execution_graph.write().await;
+        let session_id = graph.session_id().to_string();
+        let mut black_list = vec![];
+        while let Some((running_stage, task_id_gen)) =
+            graph.fetch_running_stage(&black_list)
+        {
+            let scan_files = get_scan_files(job_id, 
running_stage.plan.clone())?;
+            if is_skip_consistent_hash(&scan_files) {
+                info!(
+                    "Will skip stage {}/{} for consistent hashing task 
binding",
+                    job_id, running_stage.stage_id
+                );
+                black_list.push(running_stage.stage_id);
+                continue;
+            }
+            let pre_total_slots = total_slots;
+            let scan_files = &scan_files[0];
+            let tolerance_list = vec![0, tolerance];
+            // First round with 0 tolerance consistent hashing policy
+            // Second round with [`tolerance`] tolerance consistent hashing 
policy
+            for tolerance in tolerance_list {
+                let runnable_tasks = running_stage
+                    .task_infos
+                    .iter_mut()
+                    .enumerate()
+                    .filter(|(_partition, info)| info.is_none())
+                    .take(total_slots)
+                    .collect::<Vec<_>>();
+                for (partition_id, task_info) in runnable_tasks {
+                    let partition_files = &scan_files[partition_id];
+                    assert!(!partition_files.is_empty());
+                    // Currently we choose the first file for a task for 
consistent hash.
+                    // Later when splitting files for tasks in datafusion, 
it's better to
+                    // introduce this hash based policy besides the file 
number policy or file size policy.
+                    let file_for_hash = &partition_files[0];
+                    if let Some(node) = ch_topology.get_mut_with_tolerance(
+                        file_for_hash.object_meta.location.as_ref().as_bytes(),
+                        tolerance,
+                    ) {
+                        let executor_id = node.id.clone();
+                        let task_id = *task_id_gen;
+                        *task_id_gen += 1;
+                        *task_info = 
Some(create_task_info(executor_id.clone(), task_id));
+
+                        let partition = PartitionId {
+                            job_id: job_id.clone(),
+                            stage_id: running_stage.stage_id,
+                            partition_id,
+                        };
+                        let data_cache = tolerance == 0;
+                        let task_desc = TaskDescription {
+                            session_id: session_id.clone(),
+                            partition,
+                            stage_attempt_num: running_stage.stage_attempt_num,
+                            task_id,
+                            task_attempt: running_stage.task_failure_numbers
+                                [partition_id],
+                            data_cache,
+                            plan: running_stage.plan.clone(),
+                        };
+                        schedulable_tasks.push((executor_id, task_desc));
+
+                        node.available_slots -= 1;
+                        total_slots -= 1;
+                        if total_slots == 0 {
+                            return Ok((schedulable_tasks, Some(ch_topology)));
+                        }
+                    }
+                }
+            }
+            // Since there's no more tasks from this stage which can be bound,
+            // we should skip this stage at the next round.
+            if pre_total_slots == total_slots {
+                black_list.push(running_stage.stage_id);
+            }
+        }
+    }
+
+    Ok((schedulable_tasks, Some(ch_topology)))
+}
+
+/// Returns `true` when a stage must not be bound with consistent hashing.
+///
+/// `scan_files` holds one entry per plan in the stage that scans files. Consistent hashing
+/// is applied only when there is exactly one such plan. With none, there is nothing to hash
+/// on; with several, there is no single file list to drive task placement.
+pub(crate) fn is_skip_consistent_hash(scan_files: &[Vec<Vec<PartitionedFile>>]) -> bool {
+    scan_files.len() != 1
+}
+
+/// An executor as seen by the consistent-hash ring used for task binding.
+#[derive(Clone, Debug)]
+pub struct TopologyNode {
+    /// Executor id; tasks bound to this node are assigned to this id.
+    pub id: String,
+    /// `host:port` of the executor, reported as the node name through `Node::name`.
+    pub name: String,
+    /// Timestamp at which the executor was last seen (units as provided by the caller — TODO confirm).
+    pub last_seen_ts: u64,
+    /// Number of task slots still free on the executor.
+    pub available_slots: u32,
+}
+
+impl TopologyNode {
+    /// Creates a node for executor `id`, named `host:port`.
+    fn new(
+        host: &str,
+        port: u16,
+        id: &str,
+        last_seen_ts: u64,
+        available_slots: u32,
+    ) -> Self {
+        Self {
+            id: id.to_string(),
+            name: format!("{host}:{port}"),
+            last_seen_ts,
+            available_slots,
+        }
+    }
+}
+
+impl consistent_hash::node::Node for TopologyNode {
+    fn name(&self) -> &str {
+        &self.name
+    }
+
+    /// A node is only valid for binding while it still has at least one free slot.
+    fn is_valid(&self) -> bool {
+        self.available_slots > 0
+    }
+}
+
 #[cfg(test)]
 mod test {
-    use crate::cluster::{bind_task_bias, bind_task_round_robin, BoundTask};
-    use crate::state::execution_graph::ExecutionGraph;
-    use crate::state::task_manager::JobInfoCache;
-    use crate::test_utils::{mock_completed_task, 
test_aggregation_plan_with_job_id};
+    use std::collections::HashMap;
+    use std::sync::Arc;
+
+    use datafusion::datasource::listing::PartitionedFile;
+    use object_store::path::Path;
+    use object_store::ObjectMeta;
+
     use ballista_core::error::Result;
     use ballista_core::serde::protobuf::AvailableTaskSlots;
     use ballista_core::serde::scheduler::{ExecutorMetadata, 
ExecutorSpecification};
-    use std::collections::HashMap;
-    use std::sync::Arc;
+
+    use crate::cluster::{
+        bind_task_bias, bind_task_consistent_hash, bind_task_round_robin, 
BoundTask,
+        TopologyNode,
+    };
+    use crate::state::execution_graph::ExecutionGraph;
+    use crate::state::task_manager::JobInfoCache;
+    use crate::test_utils::{mock_completed_task, 
test_aggregation_plan_with_job_id};
 
     #[tokio::test]
     async fn test_bind_task_bias() -> Result<()> {
-        let active_jobs = mock_active_jobs().await?;
+        let num_partition = 8usize;
+        let active_jobs = mock_active_jobs(num_partition).await?;
         let mut available_slots = mock_available_slots();
         let available_slots_ref: Vec<&mut AvailableTaskSlots> =
             available_slots.iter_mut().collect();
@@ -613,7 +791,8 @@ mod test {
 
     #[tokio::test]
     async fn test_bind_task_round_robin() -> Result<()> {
-        let active_jobs = mock_active_jobs().await?;
+        let num_partition = 8usize;
+        let active_jobs = mock_active_jobs(num_partition).await?;
         let mut available_slots = mock_available_slots();
         let available_slots_ref: Vec<&mut AvailableTaskSlots> =
             available_slots.iter_mut().collect();
@@ -669,6 +848,104 @@ mod test {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn test_bind_task_consistent_hash() -> Result<()> {
+        // Consts (not locals) so the non-capturing closures below can use them and still
+        // coerce to the `fn` pointer expected by `bind_task_consistent_hash`.
+        const NUM_PARTITION: usize = 8;
+        const NUM_REPLICAS: usize = 31;
+        const TOLERANCE: usize = 0;
+
+        let active_jobs = Arc::new(mock_active_jobs(NUM_PARTITION).await?);
+        let topology_nodes = mock_topology_nodes();
+
+        // Stages that report no scan files are never bound by consistent hashing.
+        let (unbound_tasks, _) = bind_task_consistent_hash(
+            topology_nodes.clone(),
+            NUM_REPLICAS,
+            TOLERANCE,
+            Arc::clone(&active_jobs),
+            |_, _| Ok(vec![]),
+        )
+        .await?;
+        assert_eq!(0, unbound_tasks.len());
+
+        // Only job_b reports scan files, so only its tasks get bound.
+        let (bound_tasks, _) = bind_task_consistent_hash(
+            topology_nodes,
+            NUM_REPLICAS,
+            TOLERANCE,
+            active_jobs,
+            |job_id, _| mock_get_scan_files("job_b", job_id, NUM_PARTITION),
+        )
+        .await?;
+        assert_eq!(6, bound_tasks.len());
+
+        let result = get_result(bound_tasks);
+        let expected: HashMap<String, HashMap<String, usize>> = HashMap::from([(
+            "job_b".to_string(),
+            HashMap::from([
+                ("executor_1".to_string(), 1),
+                ("executor_2".to_string(), 3),
+                ("executor_3".to_string(), 2),
+            ]),
+        )]);
+        assert!(
+            expected == result,
+            "The result {:?} is not as expected {:?}",
+            result,
+            expected
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_bind_task_consistent_hash_with_tolerance() -> Result<()> {
+        // Consts so the non-capturing closure below still coerces to a `fn` pointer.
+        const NUM_PARTITION: usize = 8;
+        const NUM_REPLICAS: usize = 31;
+        const TOLERANCE: usize = 1;
+
+        let active_jobs = Arc::new(mock_active_jobs(NUM_PARTITION).await?);
+
+        // With a non-zero tolerance, a second round can bind tasks whose hashed node is full.
+        let (bound_tasks, _) = bind_task_consistent_hash(
+            mock_topology_nodes(),
+            NUM_REPLICAS,
+            TOLERANCE,
+            active_jobs,
+            |job_id, _| mock_get_scan_files("job_b", job_id, NUM_PARTITION),
+        )
+        .await?;
+        assert_eq!(7, bound_tasks.len());
+
+        let result = get_result(bound_tasks);
+        let expected: HashMap<String, HashMap<String, usize>> = HashMap::from([(
+            "job_b".to_string(),
+            HashMap::from([
+                ("executor_1".to_string(), 1),
+                ("executor_2".to_string(), 3),
+                ("executor_3".to_string(), 3),
+            ]),
+        )]);
+        assert!(
+            expected == result,
+            "The result {:?} is not as expected {:?}",
+            result,
+            expected
+        );
+
+        Ok(())
+    }
+
     fn get_result(
         bound_tasks: Vec<BoundTask>,
     ) -> HashMap<String, HashMap<String, usize>> {
@@ -685,9 +962,9 @@ mod test {
         result
     }
 
-    async fn mock_active_jobs() -> Result<HashMap<String, JobInfoCache>> {
-        let num_partition = 8usize;
-
+    async fn mock_active_jobs(
+        num_partition: usize,
+    ) -> Result<HashMap<String, JobInfoCache>> {
         let graph_a = mock_graph("job_a", num_partition, 2).await?;
 
         let graph_b = mock_graph("job_b", num_partition, 7).await?;
@@ -746,4 +1023,51 @@ mod test {
             },
         ]
     }
+
+    /// Three local executors with 1, 3 and 5 free slots, keyed by executor id.
+    fn mock_topology_nodes() -> HashMap<String, TopologyNode> {
+        [("executor_1", 8081, 1), ("executor_2", 8082, 3), ("executor_3", 8083, 5)]
+            .into_iter()
+            .map(|(executor_id, port, slots)| {
+                let node = TopologyNode::new("localhost", port, executor_id, 0, slots);
+                (executor_id.to_string(), node)
+            })
+            .collect()
+    }
+
+    /// Reports `num_partition` mock scan files for `expected_job_id`, and none for other jobs.
+    fn mock_get_scan_files(
+        expected_job_id: &str,
+        job_id: &str,
+        num_partition: usize,
+    ) -> datafusion::common::Result<Vec<Vec<Vec<PartitionedFile>>>> {
+        let scan_files = if job_id == expected_job_id {
+            mock_scan_files(num_partition)
+        } else {
+            Vec::new()
+        };
+        Ok(scan_files)
+    }
+
+    /// One scanning plan whose partition `i` holds the single file `file--{i}`.
+    fn mock_scan_files(num_partition: usize) -> Vec<Vec<Vec<PartitionedFile>>> {
+        let partitions = (0..num_partition)
+            .map(|i| {
+                vec![PartitionedFile {
+                    object_meta: ObjectMeta {
+                        location: Path::from(format!("file--{i}")),
+                        last_modified: Default::default(),
+                        size: 1,
+                        e_tag: None,
+                    },
+                    partition_values: vec![],
+                    range: None,
+                    extensions: None,
+                }]
+            })
+            .collect();
+        vec![partitions]
+    }
 }
diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs
index 0cb44a73..7a0c10c4 100644
--- a/ballista/scheduler/src/config.rs
+++ b/ballista/scheduler/src/config.rs
@@ -37,7 +37,7 @@ pub struct SchedulerConfig {
     /// The event loop buffer size. for a system of high throughput, a larger 
value like 1000000 is recommended
     pub event_loop_buffer_size: u32,
     /// Policy of distributing tasks to available executor slots. For a 
cluster with single scheduler, round-robin is recommended
-    pub task_distribution: TaskDistribution,
+    pub task_distribution: TaskDistributionPolicy,
     /// The delayed interval for cleaning up finished job data, mainly the 
shuffle data, 0 means the cleaning up is disabled
     pub finished_job_data_clean_up_interval_seconds: u64,
     /// The delayed interval for cleaning up finished job state stored in the 
backend, 0 means the cleaning up is disabled.
@@ -70,7 +70,7 @@ impl Default for SchedulerConfig {
             bind_port: 50050,
             scheduling_policy: TaskSchedulingPolicy::PullStaged,
             event_loop_buffer_size: 10000,
-            task_distribution: TaskDistribution::Bias,
+            task_distribution: TaskDistributionPolicy::Bias,
             finished_job_data_clean_up_interval_seconds: 300,
             finished_job_state_clean_up_interval_seconds: 3600,
             advertise_flight_sql_endpoint: None,
@@ -143,7 +143,7 @@ impl SchedulerConfig {
         self
     }
 
-    pub fn with_task_distribution(mut self, policy: TaskDistribution) -> Self {
+    pub fn with_task_distribution(mut self, policy: TaskDistributionPolicy) -> 
Self {
         self.task_distribution = policy;
         self
     }
@@ -186,9 +186,14 @@ pub enum TaskDistribution {
     /// Eagerly assign tasks to executor slots. This will assign as many task 
slots per executor
     /// as are currently available
     Bias,
-    /// Distributed tasks evenly across executors. This will try and iterate 
through available executors
+    /// Distribute tasks evenly across executors. This will try and iterate 
through available executors
     /// and assign one task to each executor until all tasks are assigned.
     RoundRobin,
+    /// 1. Firstly, try to bind tasks without scanning source files by 
[`RoundRobin`] policy.
+    /// 2. Then for a task for scanning source files, firstly calculate a hash 
value based on input files.
+    /// And then bind it with an execute according to consistent hashing 
policy.
+    /// 3. If needed, work stealing can be enabled based on the tolerance of 
the consistent hashing.
+    ConsistentHash,
 }
 
 impl std::str::FromStr for TaskDistribution {
@@ -204,3 +209,21 @@ impl parse_arg::ParseArgFromStr for TaskDistribution {
         write!(writer, "The executor slots policy for the scheduler")
     }
 }
+
+#[derive(Debug, Clone, Copy)]
+pub enum TaskDistributionPolicy {
+    /// Eagerly assign tasks to executor slots. This will assign as many task 
slots per executor
+    /// as are currently available
+    Bias,
+    /// Distribute tasks evenly across executors. This will try and iterate 
through available executors
+    /// and assign one task to each executor until all tasks are assigned.
+    RoundRobin,
+    /// 1. Firstly, try to bind tasks without scanning source files by 
[`RoundRobin`] policy.
+    /// 2. Then for a task for scanning source files, firstly calculate a hash 
value based on input files.
+    /// And then bind it with an execute according to consistent hashing 
policy.
+    /// 3. If needed, work stealing can be enabled based on the tolerance of 
the consistent hashing.
+    ConsistentHash {
+        num_replicas: usize,
+        tolerance: usize,
+    },
+}
diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs 
b/ballista/scheduler/src/scheduler_server/grpc.rs
index f1230c71..8257891c 100644
--- a/ballista/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/scheduler/src/scheduler_server/grpc.rs
@@ -48,7 +48,7 @@ use std::ops::Deref;
 use std::sync::Arc;
 
 use crate::cluster::{bind_task_bias, bind_task_round_robin};
-use crate::config::TaskDistribution;
+use crate::config::TaskDistributionPolicy;
 use crate::scheduler_server::event::QueryStageSchedulerEvent;
 use datafusion::prelude::SessionContext;
 use std::time::{SystemTime, UNIX_EPOCH};
@@ -123,12 +123,16 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
             let available_slots = available_slots.iter_mut().collect();
             let active_jobs = self.state.task_manager.get_running_job_cache();
             let schedulable_tasks = match self.state.config.task_distribution {
-                TaskDistribution::Bias => {
+                TaskDistributionPolicy::Bias => {
                     bind_task_bias(available_slots, active_jobs, |_| 
false).await
                 }
-                TaskDistribution::RoundRobin => {
+                TaskDistributionPolicy::RoundRobin => {
                     bind_task_round_robin(available_slots, active_jobs, |_| 
false).await
                 }
+                TaskDistributionPolicy::ConsistentHash{..} => {
+                    return Err(Status::unimplemented(
+                        "ConsistentHash TaskDistribution is not feasible for 
pull-based task scheduling"))
+                }
             };
 
             let mut tasks = vec![];
diff --git a/ballista/scheduler/src/state/execution_graph.rs 
b/ballista/scheduler/src/state/execution_graph.rs
index f5ea22d0..ba1bbd2b 100644
--- a/ballista/scheduler/src/state/execution_graph.rs
+++ b/ballista/scheduler/src/state/execution_graph.rs
@@ -911,6 +911,7 @@ impl ExecutionGraph {
                     stage_attempt_num: stage.stage_attempt_num,
                     task_id,
                     task_attempt,
+                    data_cache: false,
                     plan: stage.plan.clone(),
                 })
             } else {
@@ -1617,6 +1618,7 @@ pub struct TaskDescription {
     pub stage_attempt_num: usize,
     pub task_id: usize,
     pub task_attempt: usize,
+    pub data_cache: bool,
     pub plan: Arc<dyn ExecutionPlan>,
 }
 
@@ -1625,7 +1627,7 @@ impl Debug for TaskDescription {
         let plan = 
DisplayableExecutionPlan::new(self.plan.as_ref()).indent(false);
         write!(
             f,
-            "TaskDescription[session_id: {},job: {}, stage: {}.{}, partition: 
{} task_id {}, task attempt {}]\n{}",
+            "TaskDescription[session_id: {},job: {}, stage: {}.{}, partition: 
{} task_id {}, task attempt {}, data cache {}]\n{}",
             self.session_id,
             self.partition.job_id,
             self.partition.stage_id,
@@ -1633,6 +1635,7 @@ impl Debug for TaskDescription {
             self.partition.partition_id,
             self.task_id,
             self.task_attempt,
+            self.data_cache,
             plan
         )
     }
diff --git a/ballista/scheduler/src/state/task_manager.rs 
b/ballista/scheduler/src/state/task_manager.rs
index 67f2857b..864bf799 100644
--- a/ballista/scheduler/src/state/task_manager.rs
+++ b/ballista/scheduler/src/state/task_manager.rs
@@ -27,7 +27,8 @@ use ballista_core::error::Result;
 
 use crate::cluster::JobState;
 use ballista_core::serde::protobuf::{
-    job_status, JobStatus, MultiTaskDefinition, TaskDefinition, TaskId, 
TaskStatus,
+    job_status, JobStatus, KeyValuePair, MultiTaskDefinition, TaskDefinition, 
TaskId,
+    TaskStatus,
 };
 use ballista_core::serde::scheduler::ExecutorMetadata;
 use ballista_core::serde::BallistaCodec;
@@ -46,6 +47,7 @@ use std::time::Duration;
 use std::time::{SystemTime, UNIX_EPOCH};
 use tokio::sync::RwLock;
 
+use ballista_core::config::BALLISTA_DATA_CACHE_ENABLED;
 use tracing::trace;
 
 type ActiveJobCache = Arc<DashMap<String, JobInfoCache>>;
@@ -494,6 +496,14 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
                 plan_buf
             };
 
+            let mut props = vec![];
+            if task.data_cache {
+                props.push(KeyValuePair {
+                    key: BALLISTA_DATA_CACHE_ENABLED.to_string(),
+                    value: "true".to_string(),
+                });
+            }
+
             let task_definition = TaskDefinition {
                 task_id: task.task_id as u32,
                 task_attempt_num: task.task_attempt as u32,
@@ -507,7 +517,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
                     .duration_since(UNIX_EPOCH)
                     .unwrap()
                     .as_millis() as u64,
-                props: vec![],
+                props,
             };
             Ok(task_definition)
         } else {
@@ -524,14 +534,21 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
         tasks: Vec<Vec<TaskDescription>>,
         executor_manager: &ExecutorManager,
     ) -> Result<()> {
-        let multi_tasks: Result<Vec<MultiTaskDefinition>> = tasks
-            .into_iter()
-            .map(|stage_tasks| self.prepare_multi_task_definition(stage_tasks))
-            .collect();
+        let mut multi_tasks = vec![];
+        for stage_tasks in tasks {
+            match self.prepare_multi_task_definition(stage_tasks) {
+                Ok(stage_tasks) => multi_tasks.extend(stage_tasks),
+                Err(e) => error!("Fail to prepare task definition: {:?}", e),
+            }
+        }
 
-        self.launcher
-            .launch_tasks(executor, multi_tasks?, executor_manager)
-            .await
+        if !multi_tasks.is_empty() {
+            self.launcher
+                .launch_tasks(executor, multi_tasks, executor_manager)
+                .await
+        } else {
+            Ok(())
+        }
     }
 
     #[allow(dead_code)]
@@ -539,7 +556,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
     fn prepare_multi_task_definition(
         &self,
         tasks: Vec<TaskDescription>,
-    ) -> Result<MultiTaskDefinition> {
+    ) -> Result<Vec<MultiTaskDefinition>> {
         if let Some(task) = tasks.get(0) {
             let session_id = task.session_id.clone();
             let job_id = task.partition.job_id.clone();
@@ -574,29 +591,60 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
                     plan_buf
                 };
 
-                let task_ids = tasks
-                    .iter()
-                    .map(|task| TaskId {
-                        task_id: task.task_id as u32,
-                        task_attempt_num: task.task_attempt as u32,
-                        partition_id: task.partition.partition_id as u32,
-                    })
-                    .collect();
+                let launch_time = SystemTime::now()
+                    .duration_since(UNIX_EPOCH)
+                    .unwrap()
+                    .as_millis() as u64;
+
+                let (tasks_with_data_cache, tasks_without_data_cache): 
(Vec<_>, Vec<_>) =
+                    tasks.into_iter().partition(|task| task.data_cache);
+
+                let mut multi_tasks = vec![];
+                if !tasks_with_data_cache.is_empty() {
+                    let task_ids = tasks_with_data_cache
+                        .into_iter()
+                        .map(|task| TaskId {
+                            task_id: task.task_id as u32,
+                            task_attempt_num: task.task_attempt as u32,
+                            partition_id: task.partition.partition_id as u32,
+                        })
+                        .collect();
+                    multi_tasks.push(MultiTaskDefinition {
+                        task_ids,
+                        job_id: job_id.clone(),
+                        stage_id: stage_id as u32,
+                        stage_attempt_num: stage_attempt_num as u32,
+                        plan: plan.clone(),
+                        session_id: session_id.clone(),
+                        launch_time,
+                        props: vec![KeyValuePair {
+                            key: BALLISTA_DATA_CACHE_ENABLED.to_string(),
+                            value: "true".to_string(),
+                        }],
+                    });
+                }
+                if !tasks_without_data_cache.is_empty() {
+                    let task_ids = tasks_without_data_cache
+                        .into_iter()
+                        .map(|task| TaskId {
+                            task_id: task.task_id as u32,
+                            task_attempt_num: task.task_attempt as u32,
+                            partition_id: task.partition.partition_id as u32,
+                        })
+                        .collect();
+                    multi_tasks.push(MultiTaskDefinition {
+                        task_ids,
+                        job_id,
+                        stage_id: stage_id as u32,
+                        stage_attempt_num: stage_attempt_num as u32,
+                        plan,
+                        session_id,
+                        launch_time,
+                        props: vec![],
+                    });
+                }
 
-                let multi_task_definition = MultiTaskDefinition {
-                    task_ids,
-                    job_id,
-                    stage_id: stage_id as u32,
-                    stage_attempt_num: stage_attempt_num as u32,
-                    plan,
-                    session_id,
-                    launch_time: SystemTime::now()
-                        .duration_since(UNIX_EPOCH)
-                        .unwrap()
-                        .as_millis() as u64,
-                    props: vec![],
-                };
-                Ok(multi_task_definition)
+                Ok(multi_tasks)
             } else {
                 Err(BallistaError::General(format!("Cannot prepare multi task 
definition for job {job_id} which is not in active cache")))
             }

Reply via email to