This is an automated email from the ASF dual-hosted git repository.

houqp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 71757bb  Introduce push-based task scheduling for Ballista (#1560)
71757bb is described below

commit 71757bbd74dd56b800420cfe7a894de7ba882c34
Author: yahoNanJing <[email protected]>
AuthorDate: Sun Jan 23 14:52:35 2022 +0800

    Introduce push-based task scheduling for Ballista (#1560)
    
    * Remove call_ip in the SchedulerServer
    
    * Introduce push-based task scheduling
    
    Co-authored-by: yangzhong <[email protected]>
---
 ballista/rust/core/Cargo.toml                      |   2 +
 ballista/rust/core/proto/ballista.proto            |  94 +++++
 ballista/rust/core/src/config.rs                   |  17 +
 ballista/rust/core/src/serde/scheduler/mod.rs      | 141 +++++++
 ballista/rust/executor/Cargo.toml                  |   1 +
 ballista/rust/executor/executor_config_spec.toml   |  13 +
 ballista/rust/executor/src/execution_loop.rs       |  38 +-
 ballista/rust/executor/src/executor.rs             |  15 +
 ballista/rust/executor/src/executor_server.rs      | 291 ++++++++++++++
 ballista/rust/executor/src/lib.rs                  |  39 ++
 ballista/rust/executor/src/main.rs                 |  65 ++-
 ballista/rust/executor/src/standalone.rs           |   2 +
 ballista/rust/scheduler/scheduler_config_spec.toml |   9 +-
 ballista/rust/scheduler/src/lib.rs                 | 439 +++++++++++++++++++--
 ballista/rust/scheduler/src/main.rs                |  49 ++-
 ballista/rust/scheduler/src/standalone.rs          |   6 +-
 ballista/rust/scheduler/src/state/mod.rs           | 127 +++++-
 17 files changed, 1245 insertions(+), 103 deletions(-)

diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml
index 060de42..85d0e7a 100644
--- a/ballista/rust/core/Cargo.toml
+++ b/ballista/rust/core/Cargo.toml
@@ -42,6 +42,8 @@ tokio = "1.0"
 tonic = "0.6"
 uuid = { version = "0.8", features = ["v4"] }
 chrono = { version = "0.4", default-features = false }
+clap = "2"
+parse_arg = "0.1.3"
 
 arrow-flight = { version = "7.0.0"  }
 datafusion = { path = "../../../datafusion", version = "6.0.0" }
diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index a0bb841..b3adb36 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -838,6 +838,7 @@ message ExecutorMetadata {
   string id = 1;
   string host = 2;
   uint32 port = 3;
+  uint32 grpc_port = 4;
 }
 
 message ExecutorRegistration {
@@ -848,12 +849,46 @@ message ExecutorRegistration {
     string host = 2;
   }
   uint32 port = 3;
+  uint32 grpc_port = 4;
 }
 
 message ExecutorHeartbeat {
   ExecutorMetadata meta = 1;
   // Unix epoch-based timestamp in seconds
   uint64 timestamp = 2;
+  ExecutorState state = 3;
+}
+
+message ExecutorState {
+  repeated ExecutorMetric metrics = 1;
+}
+
+message ExecutorMetric {
+  // TODO add more metrics
+  oneof metric {
+    uint64 available_memory = 1;
+  }
+}
+
+message ExecutorSpecification {
+  repeated ExecutorResource resources = 1;
+}
+
+message ExecutorResource {
+  // TODO add more resources
+  oneof resource {
+    uint32 task_slots = 1;
+  }
+}
+
+message ExecutorData {
+  string executor_id = 1;
+  repeated ExecutorResourcePair resources = 2;
+}
+
+message ExecutorResourcePair {
+  ExecutorResource total = 1;
+  ExecutorResource available = 2;
 }
 
 message RunningTask {
@@ -906,6 +941,41 @@ message PollWorkResult {
   TaskDefinition task = 1;
 }
 
+message RegisterExecutorParams {
+  ExecutorRegistration metadata = 1;
+  ExecutorSpecification specification = 2;
+}
+
+message RegisterExecutorResult {
+  bool success = 1;
+}
+
+message SendHeartBeatParams {
+  ExecutorRegistration metadata = 1;
+  ExecutorState state = 2;
+}
+
+message SendHeartBeatResult {
+  // TODO it's from Spark for BlockManager
+  bool reregister = 1;
+}
+
+message StopExecutorParams {
+}
+
+message StopExecutorResult {
+}
+
+message UpdateTaskStatusParams {
+  ExecutorRegistration metadata = 1;
+  // All tasks must be reported until they reach the failed or completed state
+  repeated TaskStatus task_status = 2;
+}
+
+message UpdateTaskStatusResult {
+  bool success = 1;
+}
+
 message ExecuteQueryParams {
   oneof query {
     LogicalPlanNode logical_plan = 1;
@@ -965,10 +1035,28 @@ message FilePartitionMetadata {
   repeated string filename = 1;
 }
 
+message LaunchTaskParams {
+  // Allow to launch a task set to an executor at once
+  repeated TaskDefinition task = 1;
+}
+
+message LaunchTaskResult {
+  bool success = 1;
+  // TODO when part of the task set are scheduled successfully
+}
+
 service SchedulerGrpc {
   // Executors must poll the scheduler for heartbeat and to receive tasks
   rpc PollWork (PollWorkParams) returns (PollWorkResult) {}
 
+  rpc RegisterExecutor(RegisterExecutorParams) returns (RegisterExecutorResult) {}
+
+  // Push-based task scheduler will only leverage this interface
+  // rather than the PollWork interface to report executor states
+  rpc SendHeartBeat (SendHeartBeatParams) returns (SendHeartBeatResult) {}
+
+  rpc UpdateTaskStatus (UpdateTaskStatusParams) returns (UpdateTaskStatusResult) {}
+
   rpc GetFileMetadata (GetFileMetadataParams) returns (GetFileMetadataResult) {}
 
   rpc ExecuteQuery (ExecuteQueryParams) returns (ExecuteQueryResult) {}
@@ -976,6 +1064,12 @@ service SchedulerGrpc {
   rpc GetJobStatus (GetJobStatusParams) returns (GetJobStatusResult) {}
 }
 
+service ExecutorGrpc {
+  rpc LaunchTask (LaunchTaskParams) returns (LaunchTaskResult) {}
+
+  rpc StopExecutor (StopExecutorParams) returns (StopExecutorResult) {}
+}
+
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 // Arrow Data Types
 ///////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/ballista/rust/core/src/config.rs b/ballista/rust/core/src/config.rs
index 2256808..12e50e2 100644
--- a/ballista/rust/core/src/config.rs
+++ b/ballista/rust/core/src/config.rs
@@ -18,6 +18,7 @@
 
 //! Ballista configuration
 
+use clap::arg_enum;
 use core::fmt;
 use std::collections::HashMap;
 use std::result;
@@ -196,6 +197,22 @@ impl BallistaConfig {
     }
 }
 
+// an enum used to configure the scheduler policy
+// needs to be visible to code generated by configure_me
+arg_enum! {
+    #[derive(Clone, Copy, Debug, serde::Deserialize)]
+    pub enum TaskSchedulingPolicy {
+        PullStaged,
+        PushStaged,
+    }
+}
+
+impl parse_arg::ParseArgFromStr for TaskSchedulingPolicy {
+    fn describe_type<W: fmt::Write>(mut writer: W) -> fmt::Result {
+        write!(writer, "The scheduler policy for the scheduler")
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/ballista/rust/core/src/serde/scheduler/mod.rs b/ballista/rust/core/src/serde/scheduler/mod.rs
index 8c13c32..43438e2 100644
--- a/ballista/rust/core/src/serde/scheduler/mod.rs
+++ b/ballista/rust/core/src/serde/scheduler/mod.rs
@@ -77,6 +77,7 @@ pub struct ExecutorMeta {
     pub id: String,
     pub host: String,
     pub port: u16,
+    pub grpc_port: u16,
 }
 
 #[allow(clippy::from_over_into)]
@@ -86,6 +87,7 @@ impl Into<protobuf::ExecutorMetadata> for ExecutorMeta {
             id: self.id,
             host: self.host,
             port: self.port as u32,
+            grpc_port: self.grpc_port as u32,
         }
     }
 }
@@ -96,10 +98,149 @@ impl From<protobuf::ExecutorMetadata> for ExecutorMeta {
             id: meta.id,
             host: meta.host,
             port: meta.port as u16,
+            grpc_port: meta.grpc_port as u16,
         }
     }
 }
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
+pub struct ExecutorSpecification {
+    pub task_slots: u32,
+}
+
+#[allow(clippy::from_over_into)]
+impl Into<protobuf::ExecutorSpecification> for ExecutorSpecification {
+    fn into(self) -> protobuf::ExecutorSpecification {
+        protobuf::ExecutorSpecification {
+            resources: vec![protobuf::executor_resource::Resource::TaskSlots(
+                self.task_slots,
+            )]
+            .into_iter()
+            .map(|r| protobuf::ExecutorResource { resource: Some(r) })
+            .collect(),
+        }
+    }
+}
+
+impl From<protobuf::ExecutorSpecification> for ExecutorSpecification {
+    fn from(input: protobuf::ExecutorSpecification) -> Self {
+        let mut ret = Self { task_slots: 0 };
+        for resource in input.resources {
+            if let Some(protobuf::executor_resource::Resource::TaskSlots(task_slots)) =
+                resource.resource
+            {
+                ret.task_slots = task_slots
+            }
+        }
+        ret
+    }
+}
+
+#[derive(Debug, Clone, Serialize)]
+pub struct ExecutorData {
+    pub executor_id: String,
+    pub total_task_slots: u32,
+    pub available_task_slots: u32,
+}
+
+struct ExecutorResourcePair {
+    total: protobuf::executor_resource::Resource,
+    available: protobuf::executor_resource::Resource,
+}
+
+#[allow(clippy::from_over_into)]
+impl Into<protobuf::ExecutorData> for ExecutorData {
+    fn into(self) -> protobuf::ExecutorData {
+        protobuf::ExecutorData {
+            executor_id: self.executor_id,
+            resources: vec![ExecutorResourcePair {
+                total: protobuf::executor_resource::Resource::TaskSlots(
+                    self.total_task_slots,
+                ),
+                available: protobuf::executor_resource::Resource::TaskSlots(
+                    self.available_task_slots,
+                ),
+            }]
+            .into_iter()
+            .map(|r| protobuf::ExecutorResourcePair {
+                total: Some(protobuf::ExecutorResource {
+                    resource: Some(r.total),
+                }),
+                available: Some(protobuf::ExecutorResource {
+                    resource: Some(r.available),
+                }),
+            })
+            .collect(),
+        }
+    }
+}
+
+impl From<protobuf::ExecutorData> for ExecutorData {
+    fn from(input: protobuf::ExecutorData) -> Self {
+        let mut ret = Self {
+            executor_id: input.executor_id,
+            total_task_slots: 0,
+            available_task_slots: 0,
+        };
+        for resource in input.resources {
+            if let Some(task_slots) = resource.total {
+                if let Some(protobuf::executor_resource::Resource::TaskSlots(
+                    task_slots,
+                )) = task_slots.resource
+                {
+                    ret.total_task_slots = task_slots
+                }
+            };
+            if let Some(task_slots) = resource.available {
+                if let Some(protobuf::executor_resource::Resource::TaskSlots(
+                    task_slots,
+                )) = task_slots.resource
+                {
+                    ret.available_task_slots = task_slots
+                }
+            };
+        }
+        ret
+    }
+}
+
+#[derive(Debug, Clone, Copy, Serialize)]
+pub struct ExecutorState {
+    // in bytes
+    pub available_memory_size: u64,
+}
+
+#[allow(clippy::from_over_into)]
+impl Into<protobuf::ExecutorState> for ExecutorState {
+    fn into(self) -> protobuf::ExecutorState {
+        protobuf::ExecutorState {
+            metrics: vec![protobuf::executor_metric::Metric::AvailableMemory(
+                self.available_memory_size,
+            )]
+            .into_iter()
+            .map(|m| protobuf::ExecutorMetric { metric: Some(m) })
+            .collect(),
+        }
+    }
+}
+
+impl From<protobuf::ExecutorState> for ExecutorState {
+    fn from(input: protobuf::ExecutorState) -> Self {
+        let mut ret = Self {
+            available_memory_size: u64::MAX,
+        };
+        for metric in input.metrics {
+            if let Some(protobuf::executor_metric::Metric::AvailableMemory(
+                available_memory_size,
+            )) = metric.metric
+            {
+                ret.available_memory_size = available_memory_size
+            }
+        }
+        ret
+    }
+}
+
 /// Summary of executed partition
 #[derive(Debug, Copy, Clone, Default)]
 pub struct PartitionStats {
diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml
index c01bb20..1f1625f 100644
--- a/ballista/rust/executor/Cargo.toml
+++ b/ballista/rust/executor/Cargo.toml
@@ -45,6 +45,7 @@ tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread"] }
 tokio-stream = { version = "0.1", features = ["net"] }
 tonic = "0.6"
 uuid = { version = "0.8", features = ["v4"] }
+hyper = "0.14.4"
 
 [dev-dependencies]
 
diff --git a/ballista/rust/executor/executor_config_spec.toml b/ballista/rust/executor/executor_config_spec.toml
index 6f170c8..1dd3de9 100644
--- a/ballista/rust/executor/executor_config_spec.toml
+++ b/ballista/rust/executor/executor_config_spec.toml
@@ -55,6 +55,12 @@ default = "50051"
 doc = "bind port"
 
 [[param]]
+name = "bind_grpc_port"
+type = "u16"
+default = "50052"
+doc = "bind grpc service port"
+
+[[param]]
 name = "work_dir"
 type = "String"
 doc = "Directory for temporary IPC files"
@@ -65,3 +71,10 @@ name = "concurrent_tasks"
 type = "usize"
 default = "4"
 doc = "Max concurrent tasks."
+
+[[param]]
+abbr = "s"
+name = "task_scheduling_policy"
+type = "ballista_core::config::TaskSchedulingPolicy"
+doc = "The task scheduing policy for the scheduler, see TaskSchedulingPolicy::variants() for options. Default: PullStaged"
+default = "ballista_core::config::TaskSchedulingPolicy::PullStaged"
diff --git a/ballista/rust/executor/src/execution_loop.rs b/ballista/rust/executor/src/execution_loop.rs
index 4d12dfc..69bc838 100644
--- a/ballista/rust/executor/src/execution_loop.rs
+++ b/ballista/rust/executor/src/execution_loop.rs
@@ -26,12 +26,11 @@ use tonic::transport::Channel;
 
 use ballista_core::serde::protobuf::ExecutorRegistration;
 use ballista_core::serde::protobuf::{
-    self, scheduler_grpc_client::SchedulerGrpcClient, task_status, FailedTask,
-    PartitionId, PollWorkParams, PollWorkResult, ShuffleWritePartition, TaskDefinition,
-    TaskStatus,
+    scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult,
+    TaskDefinition, TaskStatus,
 };
-use protobuf::CompletedTask;
 
+use crate::as_task_status;
 use crate::executor::Executor;
 use ballista_core::error::BallistaError;
 use ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning;
@@ -144,37 +143,6 @@ async fn run_received_tasks(
     Ok(())
 }
 
-fn as_task_status(
-    execution_result: ballista_core::error::Result<Vec<ShuffleWritePartition>>,
-    executor_id: String,
-    task_id: PartitionId,
-) -> TaskStatus {
-    match execution_result {
-        Ok(partitions) => {
-            info!("Task {:?} finished", task_id);
-
-            TaskStatus {
-                partition_id: Some(task_id),
-                status: Some(task_status::Status::Completed(CompletedTask {
-                    executor_id,
-                    partitions,
-                })),
-            }
-        }
-        Err(e) => {
-            let error_msg = e.to_string();
-            info!("Task {:?} failed: {}", task_id, error_msg);
-
-            TaskStatus {
-                partition_id: Some(task_id),
-                status: Some(task_status::Status::Failed(FailedTask {
-                    error: format!("Task failed due to Tokio error: {}", error_msg),
-                })),
-            }
-        }
-    }
-}
-
 async fn sample_tasks_status(
     task_status_receiver: &mut Receiver<TaskStatus>,
 ) -> Vec<TaskStatus> {
diff --git a/ballista/rust/executor/src/executor.rs b/ballista/rust/executor/src/executor.rs
index ff2f08f..e7479bd 100644
--- a/ballista/rust/executor/src/executor.rs
+++ b/ballista/rust/executor/src/executor.rs
@@ -22,6 +22,7 @@ use std::sync::Arc;
 use ballista_core::error::BallistaError;
 use ballista_core::execution_plans::ShuffleWriterExec;
 use ballista_core::serde::protobuf;
+use ballista_core::serde::scheduler::ExecutorSpecification;
 use datafusion::error::DataFusionError;
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
 use datafusion::physical_plan::display::DisplayableExecutionPlan;
@@ -31,13 +32,27 @@ use datafusion::physical_plan::{ExecutionPlan, Partitioning};
 pub struct Executor {
     /// Directory for storing partial results
     work_dir: String,
+
+    /// Specification like total task slots
+    pub specification: ExecutorSpecification,
 }
 
 impl Executor {
     /// Create a new executor instance
     pub fn new(work_dir: &str) -> Self {
+        Executor::new_with_specification(
+            work_dir,
+            ExecutorSpecification { task_slots: 4 },
+        )
+    }
+
+    pub fn new_with_specification(
+        work_dir: &str,
+        specification: ExecutorSpecification,
+    ) -> Self {
         Self {
             work_dir: work_dir.to_owned(),
+            specification,
         }
     }
 }
diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs
new file mode 100644
index 0000000..3f220ea
--- /dev/null
+++ b/ballista/rust/executor/src/executor_server.rs
@@ -0,0 +1,291 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::convert::TryInto;
+use std::sync::Arc;
+use std::time::{Duration, SystemTime, UNIX_EPOCH};
+use tokio::sync::mpsc;
+
+use log::{debug, info};
+use tonic::transport::{Channel, Server};
+use tonic::{Request, Response, Status};
+
+use ballista_core::error::BallistaError;
+use ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning;
+use ballista_core::serde::protobuf::executor_grpc_server::{
+    ExecutorGrpc, ExecutorGrpcServer,
+};
+use ballista_core::serde::protobuf::executor_registration::OptionalHost;
+use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient;
+use ballista_core::serde::protobuf::{
+    ExecutorRegistration, LaunchTaskParams, LaunchTaskResult, RegisterExecutorParams,
+    SendHeartBeatParams, StopExecutorParams, StopExecutorResult, TaskDefinition,
+    UpdateTaskStatusParams,
+};
+use ballista_core::serde::scheduler::{ExecutorSpecification, ExecutorState};
+use datafusion::physical_plan::ExecutionPlan;
+
+use crate::as_task_status;
+use crate::executor::Executor;
+
+pub async fn startup(
+    mut scheduler: SchedulerGrpcClient<Channel>,
+    executor: Arc<Executor>,
+    executor_meta: ExecutorRegistration,
+) {
+    // TODO make the buffer size configurable
+    let (tx_task, rx_task) = mpsc::channel::<TaskDefinition>(1000);
+
+    let executor_server = ExecutorServer::new(
+        scheduler.clone(),
+        executor.clone(),
+        executor_meta.clone(),
+        ExecutorEnv { tx_task },
+    );
+
+    // 1. Start executor grpc service
+    {
+        let executor_meta = executor_meta.clone();
+        let addr = format!(
+            "{}:{}",
+            executor_meta
+                .optional_host
+                .map(|h| match h {
+                    OptionalHost::Host(host) => host,
+                })
+                .unwrap_or_else(|| String::from("127.0.0.1")),
+            executor_meta.grpc_port
+        );
+        let addr = addr.parse().unwrap();
+        info!("Setup executor grpc service for {:?}", addr);
+
+        let server = ExecutorGrpcServer::new(executor_server.clone());
+        let grpc_server_future = Server::builder().add_service(server).serve(addr);
+        tokio::spawn(async move { grpc_server_future.await });
+    }
+
+    let executor_server = Arc::new(executor_server);
+
+    // 2. Do executor registration
+    match register_executor(&mut scheduler, &executor_meta, &executor.specification).await
+    {
+        Ok(_) => {
+            info!("Executor registration succeed");
+        }
+        Err(error) => {
+            panic!("Executor registration failed due to: {}", error);
+        }
+    };
+
+    // 3. Start Heartbeater
+    {
+        let heartbeater = Heartbeater::new(executor_server.clone());
+        heartbeater.start().await;
+    }
+
+    // 4. Start TaskRunnerPool
+    {
+        let task_runner_pool = TaskRunnerPool::new(executor_server.clone());
+        task_runner_pool.start(rx_task).await;
+    }
+}
+
+#[allow(clippy::clone_on_copy)]
+async fn register_executor(
+    scheduler: &mut SchedulerGrpcClient<Channel>,
+    executor_meta: &ExecutorRegistration,
+    specification: &ExecutorSpecification,
+) -> Result<(), BallistaError> {
+    let result = scheduler
+        .register_executor(RegisterExecutorParams {
+            metadata: Some(executor_meta.clone()),
+            specification: Some(specification.clone().into()),
+        })
+        .await?;
+    if result.into_inner().success {
+        Ok(())
+    } else {
+        Err(BallistaError::General(
+            "Executor registration failed!!!".to_owned(),
+        ))
+    }
+}
+
+#[derive(Clone)]
+pub struct ExecutorServer {
+    _start_time: u128,
+    executor: Arc<Executor>,
+    executor_meta: ExecutorRegistration,
+    scheduler: SchedulerGrpcClient<Channel>,
+    executor_env: ExecutorEnv,
+}
+
+#[derive(Clone)]
+struct ExecutorEnv {
+    tx_task: mpsc::Sender<TaskDefinition>,
+}
+
+unsafe impl Sync for ExecutorEnv {}
+
+impl ExecutorServer {
+    fn new(
+        scheduler: SchedulerGrpcClient<Channel>,
+        executor: Arc<Executor>,
+        executor_meta: ExecutorRegistration,
+        executor_env: ExecutorEnv,
+    ) -> Self {
+        Self {
+            _start_time: SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap()
+                .as_millis(),
+            executor,
+            executor_meta,
+            scheduler,
+            executor_env,
+        }
+    }
+
+    async fn heartbeat(&self) {
+        // TODO Error handling
+        self.scheduler
+            .clone()
+            .send_heart_beat(SendHeartBeatParams {
+                metadata: Some(self.executor_meta.clone()),
+                state: Some(self.get_executor_state().await.into()),
+            })
+            .await
+            .unwrap();
+    }
+
+    async fn run_task(&self, task: TaskDefinition) -> Result<(), BallistaError> {
+        let task_id = task.task_id.unwrap();
+        let task_id_log = format!(
+            "{}/{}/{}",
+            task_id.job_id, task_id.stage_id, task_id.partition_id
+        );
+        info!("Start to run task {}", task_id_log);
+
+        let plan: Arc<dyn ExecutionPlan> = (&task.plan.unwrap()).try_into().unwrap();
+        let shuffle_output_partitioning =
+            parse_protobuf_hash_partitioning(task.output_partitioning.as_ref())?;
+
+        let execution_result = self
+            .executor
+            .execute_shuffle_write(
+                task_id.job_id.clone(),
+                task_id.stage_id as usize,
+                task_id.partition_id as usize,
+                plan,
+                shuffle_output_partitioning,
+            )
+            .await;
+        info!("Done with task {}", task_id_log);
+        debug!("Statistics: {:?}", execution_result);
+
+        // TODO use another channel to update the status of a task set
+        self.scheduler
+            .clone()
+            .update_task_status(UpdateTaskStatusParams {
+                metadata: Some(self.executor_meta.clone()),
+                task_status: vec![as_task_status(
+                    execution_result,
+                    self.executor_meta.id.clone(),
+                    task_id,
+                )],
+            })
+            .await?;
+
+        Ok(())
+    }
+
+    // TODO with real state
+    async fn get_executor_state(&self) -> ExecutorState {
+        ExecutorState {
+            available_memory_size: u64::MAX,
+        }
+    }
+}
+
+struct Heartbeater {
+    executor_server: Arc<ExecutorServer>,
+}
+
+impl Heartbeater {
+    fn new(executor_server: Arc<ExecutorServer>) -> Self {
+        Self { executor_server }
+    }
+
+    async fn start(&self) {
+        let executor_server = self.executor_server.clone();
+        tokio::spawn(async move {
+            info!("Starting heartbeater to send heartbeat the scheduler periodically");
+            loop {
+                executor_server.heartbeat().await;
+                tokio::time::sleep(Duration::from_millis(60000)).await;
+            }
+        });
+    }
+}
+
+struct TaskRunnerPool {
+    executor_server: Arc<ExecutorServer>,
+}
+
+impl TaskRunnerPool {
+    fn new(executor_server: Arc<ExecutorServer>) -> Self {
+        Self { executor_server }
+    }
+
+    async fn start(&self, mut rx_task: mpsc::Receiver<TaskDefinition>) {
+        let executor_server = self.executor_server.clone();
+        tokio::spawn(async move {
+            info!("Starting the task runner pool");
+            loop {
+                let task = rx_task.recv().await.unwrap();
+                info!("Received task {:?}", task);
+
+                let server = executor_server.clone();
+                tokio::spawn(async move {
+                    server.run_task(task).await.unwrap();
+                });
+            }
+        });
+    }
+}
+
+#[tonic::async_trait]
+impl ExecutorGrpc for ExecutorServer {
+    async fn launch_task(
+        &self,
+        request: Request<LaunchTaskParams>,
+    ) -> Result<Response<LaunchTaskResult>, Status> {
+        let tasks = request.into_inner().task;
+        let task_sender = self.executor_env.tx_task.clone();
+        for task in tasks {
+            task_sender.send(task).await.unwrap();
+        }
+        Ok(Response::new(LaunchTaskResult { success: true }))
+    }
+
+    async fn stop_executor(
+        &self,
+        _request: Request<StopExecutorParams>,
+    ) -> Result<Response<StopExecutorResult>, Status> {
+        todo!()
+    }
+}
diff --git a/ballista/rust/executor/src/lib.rs b/ballista/rust/executor/src/lib.rs
index 714698b..a2711da 100644
--- a/ballista/rust/executor/src/lib.rs
+++ b/ballista/rust/executor/src/lib.rs
@@ -20,7 +20,46 @@
 pub mod collect;
 pub mod execution_loop;
 pub mod executor;
+pub mod executor_server;
 pub mod flight_service;
 
 mod standalone;
 pub use standalone::new_standalone_executor;
+
+use log::info;
+
+use ballista_core::serde::protobuf::{
+    task_status, CompletedTask, FailedTask, PartitionId, ShuffleWritePartition,
+    TaskStatus,
+};
+
+pub fn as_task_status(
+    execution_result: ballista_core::error::Result<Vec<ShuffleWritePartition>>,
+    executor_id: String,
+    task_id: PartitionId,
+) -> TaskStatus {
+    match execution_result {
+        Ok(partitions) => {
+            info!("Task {:?} finished", task_id);
+
+            TaskStatus {
+                partition_id: Some(task_id),
+                status: Some(task_status::Status::Completed(CompletedTask {
+                    executor_id,
+                    partitions,
+                })),
+            }
+        }
+        Err(e) => {
+            let error_msg = e.to_string();
+            info!("Task {:?} failed: {}", task_id, error_msg);
+
+            TaskStatus {
+                partition_id: Some(task_id),
+                status: Some(task_status::Status::Failed(FailedTask {
+                    error: format!("Task failed due to Tokio error: {}", error_msg),
+                })),
+            }
+        }
+    }
+}
diff --git a/ballista/rust/executor/src/main.rs b/ballista/rust/executor/src/main.rs
index b411a77..2321ce3 100644
--- a/ballista/rust/executor/src/main.rs
+++ b/ballista/rust/executor/src/main.rs
@@ -21,16 +21,18 @@ use std::sync::Arc;
 
 use anyhow::{Context, Result};
 use arrow_flight::flight_service_server::FlightServiceServer;
-use ballista_executor::execution_loop;
+use ballista_executor::{execution_loop, executor_server};
 use log::info;
 use tempfile::TempDir;
 use tonic::transport::Server;
 use uuid::Uuid;
 
+use ballista_core::config::TaskSchedulingPolicy;
 use ballista_core::serde::protobuf::{
     executor_registration, scheduler_grpc_client::SchedulerGrpcClient,
     ExecutorRegistration,
 };
+use ballista_core::serde::scheduler::ExecutorSpecification;
 use ballista_core::{print_version, BALLISTA_VERSION};
 use ballista_executor::executor::Executor;
 use ballista_executor::flight_service::BallistaFlightService;
@@ -67,6 +69,7 @@ async fn main() -> Result<()> {
     let external_host = opt.external_host;
     let bind_host = opt.bind_host;
     let port = opt.bind_port;
+    let grpc_port = opt.bind_grpc_port;
 
     let addr = format!("{}:{}", bind_host, port);
     let addr = addr
@@ -94,32 +97,54 @@ async fn main() -> Result<()> {
             .clone()
             .map(executor_registration::OptionalHost::Host),
         port: port as u32,
+        grpc_port: grpc_port as u32,
     };
+    let executor_specification = ExecutorSpecification {
+        task_slots: opt.concurrent_tasks as u32,
+    };
+    let executor = Arc::new(Executor::new_with_specification(
+        &work_dir,
+        executor_specification,
+    ));
 
     let scheduler = SchedulerGrpcClient::connect(scheduler_url)
         .await
         .context("Could not connect to scheduler")?;
 
-    let executor = Arc::new(Executor::new(&work_dir));
-
-    let service = BallistaFlightService::new(executor.clone());
+    let scheduler_policy = opt.task_scheduling_policy;
+    match scheduler_policy {
+        TaskSchedulingPolicy::PushStaged => {
+            tokio::spawn(executor_server::startup(
+                scheduler,
+                executor.clone(),
+                executor_meta,
+            ));
+        }
+        _ => {
+            tokio::spawn(execution_loop::poll_loop(
+                scheduler,
+                executor.clone(),
+                executor_meta,
+                opt.concurrent_tasks,
+            ));
+        }
+    }
 
-    let server = FlightServiceServer::new(service);
-    info!(
-        "Ballista v{} Rust Executor listening on {:?}",
-        BALLISTA_VERSION, addr
-    );
-    let server_future = tokio::spawn(Server::builder().add_service(server).serve(addr));
-    tokio::spawn(execution_loop::poll_loop(
-        scheduler,
-        executor,
-        executor_meta,
-        opt.concurrent_tasks,
-    ));
+    // Arrow flight service
+    {
+        let service = BallistaFlightService::new(executor.clone());
+        let server = FlightServiceServer::new(service);
+        info!(
+            "Ballista v{} Rust Executor listening on {:?}",
+            BALLISTA_VERSION, addr
+        );
+        let server_future =
+            tokio::spawn(Server::builder().add_service(server).serve(addr));
+        server_future
+            .await
+            .context("Tokio error")?
+            .context("Could not start executor server")?;
+    }
 
-    server_future
-        .await
-        .context("Tokio error")?
-        .context("Could not start executor server")?;
     Ok(())
 }
diff --git a/ballista/rust/executor/src/standalone.rs b/ballista/rust/executor/src/standalone.rs
index 04174d4..03f47bb 100644
--- a/ballista/rust/executor/src/standalone.rs
+++ b/ballista/rust/executor/src/standalone.rs
@@ -62,6 +62,8 @@ pub async fn new_standalone_executor(
         id: Uuid::new_v4().to_string(), // assign this executor a unique ID
         optional_host: Some(OptionalHost::Host("localhost".to_string())),
         port: addr.port() as u32,
+        // TODO Make it configurable
+        grpc_port: 50020,
     };
     tokio::spawn(execution_loop::poll_loop(
         scheduler,
diff --git a/ballista/rust/scheduler/scheduler_config_spec.toml b/ballista/rust/scheduler/scheduler_config_spec.toml
index 81e77d3..cf03fc0 100644
--- a/ballista/rust/scheduler/scheduler_config_spec.toml
+++ b/ballista/rust/scheduler/scheduler_config_spec.toml
@@ -57,4 +57,11 @@ abbr = "p"
 name = "bind_port"
 type = "u16"
 default = "50050"
-doc = "bind port. Default: 50050"
\ No newline at end of file
+doc = "bind port. Default: 50050"
+
+[[param]]
+abbr = "s"
+name = "scheduler_policy"
+type = "ballista_core::config::TaskSchedulingPolicy"
+doc = "The scheduing policy for the scheduler, see TaskSchedulingPolicy::variants() for options. Default: PullStaged"
+default = "ballista_core::config::TaskSchedulingPolicy::PullStaged"
\ No newline at end of file
diff --git a/ballista/rust/scheduler/src/lib.rs b/ballista/rust/scheduler/src/lib.rs
index 61da0d9..4b2f178 100644
--- a/ballista/rust/scheduler/src/lib.rs
+++ b/ballista/rust/scheduler/src/lib.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-#![doc = include_str!("../README.md")]
+#![doc = include_str ! ("../README.md")]
 
 pub mod api;
 pub mod planner;
@@ -41,21 +41,27 @@ pub mod externalscaler {
     include!(concat!(env!("OUT_DIR"), "/externalscaler.rs"));
 }
 
+use std::collections::{HashMap, HashSet};
+use std::fmt;
 use std::{convert::TryInto, sync::Arc};
-use std::{fmt, net::IpAddr};
 
 use ballista_core::serde::protobuf::{
     execute_query_params::Query, executor_registration::OptionalHost, job_status,
     scheduler_grpc_server::SchedulerGrpc, task_status, ExecuteQueryParams,
     ExecuteQueryResult, FailedJob, FileType, GetFileMetadataParams,
     GetFileMetadataResult, GetJobStatusParams, GetJobStatusResult, JobStatus,
-    PartitionId, PollWorkParams, PollWorkResult, QueuedJob, RunningJob, TaskDefinition,
-    TaskStatus,
+    LaunchTaskParams, PartitionId, PollWorkParams, PollWorkResult, QueuedJob,
+    RegisterExecutorParams, RegisterExecutorResult, RunningJob, SendHeartBeatParams,
+    SendHeartBeatResult, TaskDefinition, TaskStatus, UpdateTaskStatusParams,
+    UpdateTaskStatusResult,
+};
+use ballista_core::serde::scheduler::{
+    ExecutorData, ExecutorMeta, ExecutorSpecification,
 };
-use ballista_core::serde::scheduler::ExecutorMeta;
 
 use clap::arg_enum;
 use datafusion::physical_plan::ExecutionPlan;
+
 #[cfg(feature = "sled")]
 extern crate sled_package as sled;
 
@@ -81,29 +87,51 @@ use crate::externalscaler::{
 };
 use crate::planner::DistributedPlanner;
 
-use log::{debug, error, info, warn};
+use log::{debug, error, info, trace, warn};
 use rand::{distributions::Alphanumeric, thread_rng, Rng};
 use tonic::{Request, Response, Status};
 
 use self::state::{ConfigBackendClient, SchedulerState};
-use ballista_core::config::BallistaConfig;
+use anyhow::Context;
+use ballista_core::config::{BallistaConfig, TaskSchedulingPolicy};
+use ballista_core::error::BallistaError;
 use ballista_core::execution_plans::ShuffleWriterExec;
+use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient;
 use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto;
 use datafusion::prelude::{ExecutionConfig, ExecutionContext};
-use std::time::{Instant, SystemTime, UNIX_EPOCH};
+use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
+use tokio::sync::{mpsc, RwLock};
+use tonic::transport::Channel;
 
 #[derive(Clone)]
 pub struct SchedulerServer {
-    caller_ip: IpAddr,
     pub(crate) state: Arc<SchedulerState>,
     start_time: u128,
+    policy: TaskSchedulingPolicy,
+    scheduler_env: Option<SchedulerEnv>,
+    executors_client: Arc<RwLock<HashMap<String, ExecutorGrpcClient<Channel>>>>,
+}
+
+#[derive(Clone)]
+pub struct SchedulerEnv {
+    pub tx_job: mpsc::Sender<String>,
 }
 
 impl SchedulerServer {
-    pub fn new(
+    pub fn new(config: Arc<dyn ConfigBackendClient>, namespace: String) -> Self {
+        SchedulerServer::new_with_policy(
+            config,
+            namespace,
+            TaskSchedulingPolicy::PullStaged,
+            None,
+        )
+    }
+
+    pub fn new_with_policy(
         config: Arc<dyn ConfigBackendClient>,
         namespace: String,
-        caller_ip: IpAddr,
+        policy: TaskSchedulingPolicy,
+        scheduler_env: Option<SchedulerEnv>,
     ) -> Self {
         let state = Arc::new(SchedulerState::new(config, namespace));
         let state_clone = state.clone();
@@ -112,17 +140,178 @@ impl SchedulerServer {
         tokio::spawn(async move { state_clone.synchronize_job_status_loop().await });
 
         Self {
-            caller_ip,
             state,
             start_time: SystemTime::now()
                 .duration_since(UNIX_EPOCH)
                 .unwrap()
                 .as_millis(),
+            policy,
+            scheduler_env,
+            executors_client: Arc::new(RwLock::new(HashMap::new())),
+        }
+    }
+
+    async fn schedule_job(&self, job_id: String) -> Result<(), BallistaError> {
+        let alive_executors = self
+            .state
+            .get_alive_executors_metadata_within_one_minute()
+            .await?;
+        let alive_executors: HashMap<String, ExecutorMeta> = alive_executors
+            .into_iter()
+            .map(|e| (e.id.clone(), e))
+            .collect();
+        let available_executors = self.state.get_available_executors_data().await?;
+        let mut available_executors: Vec<ExecutorData> = available_executors
+            .into_iter()
+            .filter(|e| alive_executors.contains_key(&e.executor_id))
+            .collect();
+
+        // In case of there's no enough resources, reschedule the tasks of the job
+        if available_executors.is_empty() {
+            let tx_job = self.scheduler_env.as_ref().unwrap().tx_job.clone();
+            // TODO Maybe it's better to use an exclusive runtime for this kind task scheduling
+            tokio::spawn(async move {
+                warn!("Not enough available executors for task running");
+                tokio::time::sleep(Duration::from_millis(100)).await;
+                tx_job.send(job_id).await.unwrap();
+            });
+            return Ok(());
+        }
+
+        let (tasks_assigment, num_tasks) =
+            self.fetch_tasks(&mut available_executors, &job_id).await?;
+        if num_tasks > 0 {
+            for (idx_executor, tasks) in tasks_assigment.into_iter().enumerate() {
+                if !tasks.is_empty() {
+                    let executor_data = &available_executors[idx_executor];
+                    debug!(
+                        "Start to launch tasks {:?} to executor {:?}",
+                        tasks, executor_data.executor_id
+                    );
+                    let mut client = {
+                        let clients = self.executors_client.read().await;
+                        info!("Size of executor clients: {:?}", clients.len());
+                        clients.get(&executor_data.executor_id).unwrap().clone()
+                    };
+                    // Update the resources first
+                    self.state.save_executor_data(executor_data.clone()).await?;
+                    // TODO check whether launching task is successful or not
+                    client.launch_task(LaunchTaskParams { task: tasks }).await?;
+                } else {
+                    // Since the task assignment policy is round robin,
+                    // if find tasks for one executor is empty, just break fast
+                    break;
+                }
+            }
+            return Ok(());
+        }
+
+        Ok(())
+    }
+
+    async fn fetch_tasks(
+        &self,
+        available_executors: &mut Vec<ExecutorData>,
+        job_id: &str,
+    ) -> Result<(Vec<Vec<TaskDefinition>>, usize), BallistaError> {
+        let mut ret: Vec<Vec<TaskDefinition>> =
+            Vec::with_capacity(available_executors.len());
+        for _idx in 0..available_executors.len() {
+            ret.push(Vec::new());
+        }
+        let mut num_tasks = 0;
+        loop {
+            info!("Go inside fetching task loop");
+            let mut has_tasks = true;
+            for (idx, executor) in available_executors.iter_mut().enumerate() {
+                if executor.available_task_slots == 0 {
+                    break;
+                }
+                let plan = self
+                    .state
+                    .assign_next_schedulable_job_task(&executor.executor_id, job_id)
+                    .await
+                    .map_err(|e| {
+                        let msg = format!("Error finding next assignable task: {}", e);
+                        error!("{}", msg);
+                        tonic::Status::internal(msg)
+                    })?;
+                if let Some((task, _plan)) = &plan {
+                    let partition_id = task.partition_id.as_ref().unwrap();
+                    info!(
+                        "Sending new task to {}: {}/{}/{}",
+                        executor.executor_id,
+                        partition_id.job_id,
+                        partition_id.stage_id,
+                        partition_id.partition_id
+                    );
+                }
+                match plan {
+                    Some((status, plan)) => {
+                        let plan_clone = plan.clone();
+                        let output_partitioning = if let Some(shuffle_writer) =
+                            plan_clone.as_any().downcast_ref::<ShuffleWriterExec>()
+                        {
+                            shuffle_writer.shuffle_output_partitioning()
+                        } else {
+                            return Err(BallistaError::General(format!(
+                                "Task root plan was not a ShuffleWriterExec: {:?}",
+                                plan_clone
+                            )));
+                        };
+
+                        ret[idx].push(TaskDefinition {
+                            plan: Some(plan.try_into().unwrap()),
+                            task_id: status.partition_id,
+                            output_partitioning: hash_partitioning_to_proto(
+                                output_partitioning,
+                            )
+                            .map_err(|_| Status::internal("TBD".to_string()))?,
+                        });
+                        executor.available_task_slots -= 1;
+                        num_tasks += 1;
+                    }
+                    _ => {
+                        // Indicate there's no more tasks to be scheduled
+                        has_tasks = false;
+                        break;
+                    }
+                }
+            }
+            if !has_tasks {
+                break;
+            }
+            let has_executors =
+                available_executors.get(0).unwrap().available_task_slots > 0;
+            if !has_executors {
+                break;
+            }
         }
+        Ok((ret, num_tasks))
     }
+}
 
-    pub fn set_caller_ip(&mut self, ip: IpAddr) {
-        self.caller_ip = ip;
+pub struct TaskScheduler {
+    scheduler_server: Arc<SchedulerServer>,
+}
+
+impl TaskScheduler {
+    pub fn new(scheduler_server: Arc<SchedulerServer>) -> Self {
+        Self { scheduler_server }
+    }
+
+    pub fn start(&self, mut rx_job: mpsc::Receiver<String>) {
+        let scheduler_server = self.scheduler_server.clone();
+        tokio::spawn(async move {
+            info!("Starting the task scheduler");
+            loop {
+                let job_id = rx_job.recv().await.unwrap();
+                info!("Fetch job {:?} to be scheduled", job_id.clone());
+
+                let server = scheduler_server.clone();
+                server.schedule_job(job_id).await.unwrap();
+            }
+        });
     }
 }
 
@@ -181,6 +370,13 @@ impl SchedulerGrpc for SchedulerServer {
         &self,
         request: Request<PollWorkParams>,
     ) -> std::result::Result<Response<PollWorkResult>, tonic::Status> {
+        if let TaskSchedulingPolicy::PushStaged = self.policy {
+            error!("Poll work interface is not supported for push-based task scheduling");
+            return Err(tonic::Status::failed_precondition(
+                "Bad request because poll work is not supported for push-based task scheduling",
+            ));
+        }
+        let remote_addr = request.remote_addr();
         if let PollWorkParams {
             metadata: Some(metadata),
             can_accept_task,
@@ -195,8 +391,9 @@ impl SchedulerGrpc for SchedulerServer {
                     .map(|h| match h {
                         OptionalHost::Host(host) => host,
                     })
-                    .unwrap_or_else(|| self.caller_ip.to_string()),
+                    .unwrap_or_else(|| remote_addr.unwrap().ip().to_string()),
                 port: metadata.port as u16,
+                grpc_port: metadata.grpc_port as u16,
             };
             let mut lock = self.state.lock().await.map_err(|e| {
                 let msg = format!("Could not lock the state: {}", e);
@@ -278,6 +475,195 @@ impl SchedulerGrpc for SchedulerServer {
         }
     }
 
+    async fn register_executor(
+        &self,
+        request: Request<RegisterExecutorParams>,
+    ) -> Result<Response<RegisterExecutorResult>, Status> {
+        let remote_addr = request.remote_addr();
+        if let RegisterExecutorParams {
+            metadata: Some(metadata),
+            specification: Some(specification),
+        } = request.into_inner()
+        {
+            info!("Received register executor request for {:?}", metadata);
+            let metadata: ExecutorMeta = ExecutorMeta {
+                id: metadata.id,
+                host: metadata
+                    .optional_host
+                    .map(|h| match h {
+                        OptionalHost::Host(host) => host,
+                    })
+                    .unwrap_or_else(|| remote_addr.unwrap().ip().to_string()),
+                port: metadata.port as u16,
+                grpc_port: metadata.grpc_port as u16,
+            };
+            // Check whether the executor starts the grpc service
+            {
+                let executor_url =
+                    format!("http://{}:{}", metadata.host, metadata.grpc_port);
+                info!("Connect to executor {:?}", executor_url);
+                let executor_client = ExecutorGrpcClient::connect(executor_url)
+                    .await
+                    .context("Could not connect to executor")
+                    .map_err(|e| tonic::Status::internal(format!("{:?}", e)))?;
+                let mut clients = self.executors_client.write().await;
+                // TODO check duplicated registration
+                clients.insert(metadata.id.clone(), executor_client);
+                info!("Size of executor clients: {:?}", clients.len());
+            }
+            let mut lock = self.state.lock().await.map_err(|e| {
+                let msg = format!("Could not lock the state: {}", e);
+                error!("{}", msg);
+                tonic::Status::internal(msg)
+            })?;
+            self.state
+                .save_executor_metadata(metadata.clone())
+                .await
+                .map_err(|e| {
+                    let msg = format!("Could not save executor metadata: {}", e);
+                    error!("{}", msg);
+                    tonic::Status::internal(msg)
+                })?;
+            let executor_spec: ExecutorSpecification = specification.into();
+            let executor_data = ExecutorData {
+                executor_id: metadata.id.clone(),
+                total_task_slots: executor_spec.task_slots,
+                available_task_slots: executor_spec.task_slots,
+            };
+            self.state
+                .save_executor_data(executor_data)
+                .await
+                .map_err(|e| {
+                    let msg = format!("Could not save executor data: {}", e);
+                    error!("{}", msg);
+                    tonic::Status::internal(msg)
+                })?;
+            lock.unlock().await;
+            Ok(Response::new(RegisterExecutorResult { success: true }))
+        } else {
+            warn!("Received invalid register executor request");
+            Err(tonic::Status::invalid_argument(
+                "Missing metadata in request",
+            ))
+        }
+    }
+
+    async fn send_heart_beat(
+        &self,
+        request: Request<SendHeartBeatParams>,
+    ) -> Result<Response<SendHeartBeatResult>, Status> {
+        let remote_addr = request.remote_addr();
+        if let SendHeartBeatParams {
+            metadata: Some(metadata),
+            state: Some(state),
+        } = request.into_inner()
+        {
+            debug!("Received heart beat request for {:?}", metadata);
+            trace!("Related executor state is {:?}", state);
+            let metadata: ExecutorMeta = ExecutorMeta {
+                id: metadata.id,
+                host: metadata
+                    .optional_host
+                    .map(|h| match h {
+                        OptionalHost::Host(host) => host,
+                    })
+                    .unwrap_or_else(|| remote_addr.unwrap().ip().to_string()),
+                port: metadata.port as u16,
+                grpc_port: metadata.grpc_port as u16,
+            };
+            {
+                let mut lock = self.state.lock().await.map_err(|e| {
+                    let msg = format!("Could not lock the state: {}", e);
+                    error!("{}", msg);
+                    tonic::Status::internal(msg)
+                })?;
+                self.state
+                    .save_executor_state(metadata, Some(state))
+                    .await
+                    .map_err(|e| {
+                        let msg = format!("Could not save executor metadata: {}", e);
+                        error!("{}", msg);
+                        tonic::Status::internal(msg)
+                    })?;
+                lock.unlock().await;
+            }
+            Ok(Response::new(SendHeartBeatResult { reregister: false }))
+        } else {
+            warn!("Received invalid executor heart beat request");
+            Err(tonic::Status::invalid_argument(
+                "Missing metadata or metrics in request",
+            ))
+        }
+    }
+
+    async fn update_task_status(
+        &self,
+        request: Request<UpdateTaskStatusParams>,
+    ) -> Result<Response<UpdateTaskStatusResult>, Status> {
+        if let UpdateTaskStatusParams {
+            metadata: Some(metadata),
+            task_status,
+        } = request.into_inner()
+        {
+            debug!("Received task status update request for {:?}", metadata);
+            trace!("Related task status is {:?}", task_status);
+            let mut jobs = HashSet::new();
+            {
+                let mut lock = self.state.lock().await.map_err(|e| {
+                    let msg = format!("Could not lock the state: {}", e);
+                    error!("{}", msg);
+                    tonic::Status::internal(msg)
+                })?;
+                let num_tasks = task_status.len();
+                for task_status in task_status {
+                    self.state
+                        .save_task_status(&task_status)
+                        .await
+                        .map_err(|e| {
+                            let msg = format!("Could not save task status: {}", e);
+                            error!("{}", msg);
+                            tonic::Status::internal(msg)
+                        })?;
+                    if task_status.partition_id.is_some() {
+                        jobs.insert(task_status.partition_id.unwrap().job_id.clone());
+                    }
+                }
+                let mut executor_data = self
+                    .state
+                    .get_executor_data(&metadata.id)
+                    .await
+                    .map_err(|e| {
+                        let msg = format!(
+                            "Could not get metadata data for id {:?}: {}",
+                            &metadata.id, e
+                        );
+                        error!("{}", msg);
+                        tonic::Status::internal(msg)
+                    })?;
+                executor_data.available_task_slots += num_tasks as u32;
+                self.state
+                    .save_executor_data(executor_data)
+                    .await
+                    .map_err(|e| {
+                        let msg = format!("Could not save metadata data: {}", e);
+                        error!("{}", msg);
+                        tonic::Status::internal(msg)
+                    })?;
+                lock.unlock().await;
+            }
+            let tx_job = self.scheduler_env.as_ref().unwrap().tx_job.clone();
+            for job_id in jobs {
+                tx_job.send(job_id).await.unwrap();
+            }
+            Ok(Response::new(UpdateTaskStatusResult { success: true }))
+        } else {
+            warn!("Received invalid task status update request");
+            Err(tonic::Status::invalid_argument(
+                "Missing metadata or task status in request",
+            ))
+        }
+    }
+
     async fn get_file_metadata(
         &self,
         request: Request<GetFileMetadataParams>,
@@ -390,6 +776,12 @@ impl SchedulerGrpc for SchedulerServer {
 
             let state = self.state.clone();
             let job_id_spawn = job_id.clone();
+            let tx_job: Option<mpsc::Sender<String>> = match self.policy {
+                TaskSchedulingPolicy::PullStaged => None,
+                TaskSchedulingPolicy::PushStaged => {
+                    Some(self.scheduler_env.as_ref().unwrap().tx_job.clone())
+                }
+            };
             tokio::spawn(async move {
                 // create physical plan using DataFusion
                 let datafusion_ctx = create_datafusion_context(&config);
@@ -503,6 +895,11 @@ impl SchedulerGrpc for SchedulerServer {
                         ));
                     }
                 }
+
+                if let Some(tx_job) = tx_job {
+                    // Send job_id to the scheduler channel
+                    tx_job.send(job_id_spawn).await.unwrap();
+                }
             });
 
             Ok(Response::new(ExecuteQueryResult { job_id }))
@@ -537,10 +934,7 @@ pub fn create_datafusion_context(config: &BallistaConfig) -> ExecutionContext {
 
 #[cfg(all(test, feature = "sled"))]
 mod test {
-    use std::{
-        net::{IpAddr, Ipv4Addr},
-        sync::Arc,
-    };
+    use std::sync::Arc;
 
     use tonic::Request;
 
@@ -558,16 +952,13 @@ mod test {
     async fn test_poll_work() -> Result<(), BallistaError> {
         let state = Arc::new(StandaloneClient::try_new_temporary()?);
         let namespace = "default";
-        let scheduler = SchedulerServer::new(
-            state.clone(),
-            namespace.to_owned(),
-            IpAddr::V4(Ipv4Addr::LOCALHOST),
-        );
+        let scheduler = SchedulerServer::new(state.clone(), namespace.to_owned());
         let state = SchedulerState::new(state, namespace.to_string());
         let exec_meta = ExecutorRegistration {
             id: "abc".to_owned(),
             optional_host: Some(OptionalHost::Host("".to_owned())),
             port: 0,
+            grpc_port: 0,
         };
         let request: Request<PollWorkParams> = Request::new(PollWorkParams {
             metadata: Some(exec_meta.clone()),
diff --git a/ballista/rust/scheduler/src/main.rs b/ballista/rust/scheduler/src/main.rs
index 23a0386..5da5bbe 100644
--- a/ballista/rust/scheduler/src/main.rs
+++ b/ballista/rust/scheduler/src/main.rs
@@ -22,8 +22,8 @@ use ballista_scheduler::externalscaler::external_scaler_server::ExternalScalerSe
 use futures::future::{self, Either, TryFutureExt};
 use hyper::{server::conn::AddrStream, service::make_service_fn, Server};
 use std::convert::Infallible;
-use std::net::{IpAddr, Ipv4Addr};
 use std::{net::SocketAddr, sync::Arc};
+use tonic::transport::server::Connected;
 use tonic::transport::Server as TonicServer;
 use tower::Service;
 
@@ -36,9 +36,14 @@ use ballista_scheduler::api::{get_routes, EitherBody, Error};
 use ballista_scheduler::state::EtcdClient;
 #[cfg(feature = "sled")]
 use ballista_scheduler::state::StandaloneClient;
-use ballista_scheduler::{state::ConfigBackendClient, ConfigBackend, SchedulerServer};
+use ballista_scheduler::{
+    state::ConfigBackendClient, ConfigBackend, SchedulerEnv, SchedulerServer,
+    TaskScheduler,
+};
 
+use ballista_core::config::TaskSchedulingPolicy;
 use log::info;
+use tokio::sync::mpsc;
 
 #[macro_use]
 extern crate configure_me;
@@ -52,29 +57,43 @@ mod config {
         "/scheduler_configure_me_config.rs"
     ));
 }
+
 use config::prelude::*;
 
 async fn start_server(
     config_backend: Arc<dyn ConfigBackendClient>,
     namespace: String,
     addr: SocketAddr,
+    policy: TaskSchedulingPolicy,
 ) -> Result<()> {
     info!(
         "Ballista v{} Scheduler listening on {:?}",
         BALLISTA_VERSION, addr
     );
     //should only call SchedulerServer::new() once in the process
-    let scheduler_server_without_caller_ip = SchedulerServer::new(
-        config_backend.clone(),
-        namespace.clone(),
-        IpAddr::V4(Ipv4Addr::UNSPECIFIED),
+    info!(
+        "Starting Scheduler grpc server with task scheduling policy of {:?}",
+        policy
     );
+    let scheduler_server = match policy {
+        TaskSchedulingPolicy::PushStaged => {
+            // TODO make the buffer size configurable
+            let (tx_job, rx_job) = mpsc::channel::<String>(10000);
+            let scheduler_server = SchedulerServer::new_with_policy(
+                config_backend.clone(),
+                namespace.clone(),
+                policy,
+                Some(SchedulerEnv { tx_job }),
+            );
+            let task_scheduler = TaskScheduler::new(Arc::new(scheduler_server.clone()));
+            task_scheduler.start(rx_job);
+            scheduler_server
+        }
+        _ => SchedulerServer::new(config_backend.clone(), namespace.clone()),
+    };
 
     Ok(Server::bind(&addr)
         .serve(make_service_fn(move |request: &AddrStream| {
-            let mut scheduler_server = scheduler_server_without_caller_ip.clone();
-            scheduler_server.set_caller_ip(request.remote_addr().ip());
-
             let scheduler_grpc_server =
                 SchedulerGrpcServer::new(scheduler_server.clone());
 
@@ -84,10 +103,16 @@ async fn start_server(
                 .add_service(scheduler_grpc_server)
                 .add_service(keda_scaler)
                 .into_service();
-            let mut warp = warp::service(get_routes(scheduler_server));
+            let mut warp = warp::service(get_routes(scheduler_server.clone()));
 
+            let connect_info = request.connect_info();
             future::ok::<_, Infallible>(tower::service_fn(
                 move |req: hyper::Request<hyper::Body>| {
+                    // Set the connect info from hyper to tonic
+                    let (mut parts, body) = req.into_parts();
+                    parts.extensions.insert(connect_info.clone());
+                    let req = http::Request::from_parts(parts, body);
+
                     let header = req.headers().get(hyper::header::ACCEPT);
                     if header.is_some() && header.unwrap().eq("application/json") {
                         return Either::Left(
@@ -163,6 +188,8 @@ async fn main() -> Result<()> {
             )
         }
     };
-    start_server(client, namespace, addr).await?;
+
+    let policy: TaskSchedulingPolicy = opt.scheduler_policy;
+    start_server(client, namespace, addr, policy).await?;
     Ok(())
 }
diff --git a/ballista/rust/scheduler/src/standalone.rs b/ballista/rust/scheduler/src/standalone.rs
index 6ab5bd6..55239d8 100644
--- a/ballista/rust/scheduler/src/standalone.rs
+++ b/ballista/rust/scheduler/src/standalone.rs
@@ -20,10 +20,7 @@ use ballista_core::{
     BALLISTA_VERSION,
 };
 use log::info;
-use std::{
-    net::{IpAddr, Ipv4Addr, SocketAddr},
-    sync::Arc,
-};
+use std::{net::SocketAddr, sync::Arc};
 use tokio::net::TcpListener;
 use tonic::transport::Server;
 
@@ -35,7 +32,6 @@ pub async fn new_standalone_scheduler() -> Result<SocketAddr> {
     let server = SchedulerGrpcServer::new(SchedulerServer::new(
         Arc::new(client),
         "ballista".to_string(),
-        IpAddr::V4(Ipv4Addr::LOCALHOST),
     ));
     // Let the OS assign a random, free port
     let listener = TcpListener::bind("localhost:0").await?;
diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs
index ef6de83..fb45579 100644
--- a/ballista/rust/scheduler/src/state/mod.rs
+++ b/ballista/rust/scheduler/src/state/mod.rs
@@ -31,7 +31,7 @@ use ballista_core::serde::protobuf::{
     ExecutorMetadata, FailedJob, FailedTask, JobStatus, PhysicalPlanNode, RunningJob,
     RunningTask, TaskStatus,
 };
-use ballista_core::serde::scheduler::PartitionStats;
+use ballista_core::serde::scheduler::{ExecutorData, PartitionStats};
 use ballista_core::{error::BallistaError, serde::scheduler::ExecutorMeta};
 use ballista_core::{error::Result, execution_plans::UnresolvedShuffleExec};
 
@@ -118,6 +118,13 @@ impl SchedulerState {
         Ok(result)
     }
 
+    pub async fn get_alive_executors_metadata_within_one_minute(
+        &self,
+    ) -> Result<Vec<ExecutorMeta>> {
+        self.get_alive_executors_metadata(Duration::from_secs(60))
+            .await
+    }
+
     pub async fn get_alive_executors_metadata(
         &self,
         last_seen_threshold: Duration,
@@ -133,6 +140,14 @@ impl SchedulerState {
     }
 
     pub async fn save_executor_metadata(&self, meta: ExecutorMeta) -> Result<()> {
+        self.save_executor_state(meta, None).await
+    }
+
+    pub async fn save_executor_state(
+        &self,
+        meta: ExecutorMeta,
+        state: Option<protobuf::ExecutorState>,
+    ) -> Result<()> {
         let key = get_executor_key(&self.namespace, &meta.id);
         let meta: ExecutorMetadata = meta.into();
         let timestamp = SystemTime::now()
@@ -142,11 +157,57 @@ impl SchedulerState {
         let heartbeat = ExecutorHeartbeat {
             meta: Some(meta),
             timestamp,
+            state,
         };
         let value: Vec<u8> = encode_protobuf(&heartbeat)?;
         self.config_client.put(key, value).await
     }
 
+    pub async fn save_executor_data(&self, executor_data: ExecutorData) -> Result<()> {
+        let key = get_executor_data_key(&self.namespace, &executor_data.executor_id);
+        let executor_data: protobuf::ExecutorData = executor_data.into();
+        let value: Vec<u8> = encode_protobuf(&executor_data)?;
+        self.config_client.put(key, value).await
+    }
+
+    pub async fn get_executors_data(&self) -> Result<Vec<ExecutorData>> {
+        let mut result = vec![];
+
+        let entries = self
+            .config_client
+            .get_from_prefix(&get_executors_data_prefix(&self.namespace))
+            .await?;
+        for (_key, entry) in entries {
+            let executor_data: protobuf::ExecutorData = decode_protobuf(&entry)?;
+            result.push(executor_data.into());
+        }
+        Ok(result)
+    }
+
+    pub async fn get_available_executors_data(&self) -> Result<Vec<ExecutorData>> {
+        let mut res = self
+            .get_executors_data()
+            .await?
+            .into_iter()
+            .filter_map(|exec| (exec.available_task_slots > 0).then(|| exec))
+            .collect::<Vec<ExecutorData>>();
+        res.sort_by(|a, b| Ord::cmp(&b.available_task_slots, &a.available_task_slots));
+        Ok(res)
+    }
+
+    pub async fn get_executor_data(&self, executor_id: &str) -> Result<ExecutorData> {
+        let key = get_executor_data_key(&self.namespace, executor_id);
+        let value = &self.config_client.get(&key).await?;
+        if value.is_empty() {
+            return Err(BallistaError::General(format!(
+                "No executor data found for {}",
+                key
+            )));
+        }
+        let value: protobuf::ExecutorData = decode_protobuf(value)?;
+        Ok(value.into())
+    }
+
     pub async fn save_job_metadata(
         &self,
         job_id: &str,
@@ -233,6 +294,18 @@ impl SchedulerState {
         Ok((&value).try_into()?)
     }
 
+    /// Fetch and decode every stored task status that belongs to `job_id`,
+    /// keyed by its config-store key.
+    ///
+    /// Fails on the first entry that cannot be decoded.
+    pub async fn get_job_tasks(
+        &self,
+        job_id: &str,
+    ) -> Result<HashMap<String, TaskStatus>> {
+        let prefix = get_task_prefix_for_job(&self.namespace, job_id);
+        let entries = self.config_client.get_from_prefix(&prefix).await?;
+        entries
+            .into_iter()
+            .map(|(task_key, raw)| Ok((task_key, decode_protobuf(&raw)?)))
+            .collect()
+    }
+
     pub async fn get_all_tasks(&self) -> Result<HashMap<String, TaskStatus>> {
         self.config_client
             .get_from_prefix(&get_task_prefix(&self.namespace))
@@ -281,6 +354,42 @@ impl SchedulerState {
         executor_id: &str,
     ) -> Result<Option<(TaskStatus, Arc<dyn ExecutionPlan>)>> {
         let tasks = self.get_all_tasks().await?;
+        self.assign_next_schedulable_task_inner(executor_id, tasks)
+            .await
+    }
+
+    /// Like the all-jobs variant, but only considers tasks of `job_id`: picks
+    /// the next schedulable one, marks it running on `executor_id`, persists
+    /// that status and returns it together with its plan.
+    ///
+    /// Returns `Ok(None)` when no task of the job is found schedulable.
+    pub async fn assign_next_schedulable_job_task(
+        &self,
+        executor_id: &str,
+        job_id: &str,
+    ) -> Result<Option<(TaskStatus, Arc<dyn ExecutionPlan>)>> {
+        self.assign_next_schedulable_task_inner(
+            executor_id,
+            self.get_job_tasks(job_id).await?,
+        )
+        .await
+    }
+
+    /// Pick the next schedulable task among `tasks`, set its status to
+    /// `Running` on `executor_id`, persist it with `save_task_status`, and
+    /// return the updated status together with the task's plan.
+    ///
+    /// Returns `Ok(None)` when `get_next_schedulable_task` finds nothing to run.
+    async fn assign_next_schedulable_task_inner(
+        &self,
+        executor_id: &str,
+        tasks: HashMap<String, TaskStatus>,
+    ) -> Result<Option<(TaskStatus, Arc<dyn ExecutionPlan>)>> {
+        let (mut status, plan) = match self.get_next_schedulable_task(tasks).await? {
+            Some(next) => next,
+            None => return Ok(None),
+        };
+        // `status` is owned here, so it is updated in place.
+        status.status = Some(task_status::Status::Running(RunningTask {
+            executor_id: executor_id.to_owned(),
+        }));
+        // NOTE(review): selection and this save are separate awaits; callers
+        // presumably serialize assignment elsewhere so a task is not handed
+        // out twice — confirm against the scheduler's locking.
+        self.save_task_status(&status).await?;
+        Ok(Some((status, plan)))
+    }
+
+    async fn get_next_schedulable_task(
+        &self,
+        tasks: HashMap<String, TaskStatus>,
+    ) -> Result<Option<(TaskStatus, Arc<dyn ExecutionPlan>)>> {
         // TODO: Make the duration a configurable parameter
         let executors = self
             .get_alive_executors_metadata(Duration::from_secs(60))
@@ -385,12 +494,7 @@ impl SchedulerState {
                     remove_unresolved_shuffles(plan.as_ref(), 
&partition_locations)?;
 
                 // If we get here, there are no more unresolved shuffled and 
the task can be run
-                let mut status = status.clone();
-                status.status = Some(task_status::Status::Running(RunningTask {
-                    executor_id: executor_id.to_owned(),
-                }));
-                self.save_task_status(&status).await?;
-                return Ok(Some((status, plan)));
+                return Ok(Some((status.clone(), plan)));
             }
         }
         Ok(None)
@@ -583,6 +687,14 @@ fn get_executor_key(namespace: &str, id: &str) -> String {
     format!("{}/{}", get_executors_prefix(namespace), id)
 }
 
+/// Config-store key prefix for per-executor `ExecutorData` entries in
+/// `namespace`: `/ballista/<namespace>/resources/executors`.
+fn get_executors_data_prefix(namespace: &str) -> String {
+    ["/ballista", namespace, "resources", "executors"].join("/")
+}
+
+/// Config-store key for the `ExecutorData` of executor `id`: the
+/// executors-data prefix followed by `/<id>`.
+fn get_executor_data_key(namespace: &str, id: &str) -> String {
+    let mut key = get_executors_data_prefix(namespace);
+    key.push('/');
+    key.push_str(id);
+    key
+}
+
 fn get_job_prefix(namespace: &str) -> String {
     format!("/ballista/{}/jobs", namespace)
 }
@@ -670,6 +782,7 @@ mod test {
             id: "123".to_owned(),
             host: "localhost".to_owned(),
             port: 123,
+            grpc_port: 124,
         };
         state.save_executor_metadata(meta.clone()).await?;
         let result: Vec<_> = state

Reply via email to