This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 6c1f6cc  ARROW-12436: [Rust][Ballista] Add watch capabilities to 
config backend trait
6c1f6cc is described below

commit 6c1f6cce4b3e27c4bfca62aa8afb2345baf2fcfb
Author: Ximo Guanter <[email protected]>
AuthorDate: Sun Apr 18 07:22:05 2021 -0600

    ARROW-12436: [Rust][Ballista] Add watch capabilities to config backend trait
    
    A small next step towards enabling HA in the scheduler. UT + ITs pass.
    
    cc @andygrove
    
    Closes #10085 from edrevo/state-watch
    
    Authored-by: Ximo Guanter <[email protected]>
    Signed-off-by: Andy Grove <[email protected]>
---
 rust/ballista/rust/scheduler/src/lib.rs            |  80 ++---
 rust/ballista/rust/scheduler/src/state/etcd.rs     |  90 +++++-
 rust/ballista/rust/scheduler/src/state/mod.rs      | 348 ++++++++++++---------
 .../rust/scheduler/src/state/standalone.rs         |  73 ++++-
 4 files changed, 377 insertions(+), 214 deletions(-)

diff --git a/rust/ballista/rust/scheduler/src/lib.rs 
b/rust/ballista/rust/scheduler/src/lib.rs
index de49bc0..a675153 100644
--- a/rust/ballista/rust/scheduler/src/lib.rs
+++ b/rust/ballista/rust/scheduler/src/lib.rs
@@ -71,8 +71,7 @@ use std::time::{Instant, SystemTime, UNIX_EPOCH};
 
 #[derive(Clone)]
 pub struct SchedulerServer {
-    state: SchedulerState,
-    namespace: String,
+    state: Arc<SchedulerState>,
     start_time: u128,
     version: String,
 }
@@ -80,10 +79,14 @@ pub struct SchedulerServer {
 impl SchedulerServer {
     pub fn new(config: Arc<dyn ConfigBackendClient>, namespace: String) -> 
Self {
         const VERSION: Option<&'static str> = option_env!("CARGO_PKG_VERSION");
+        let state = Arc::new(SchedulerState::new(config, namespace));
+        let state_clone = state.clone();
+
+        // TODO: we should elect a leader in the scheduler cluster and run 
this only in the leader
+        tokio::spawn(async move { 
state_clone.synchronize_job_status_loop().await });
 
         Self {
-            state: SchedulerState::new(config),
-            namespace,
+            state,
             start_time: SystemTime::now()
                 .duration_since(UNIX_EPOCH)
                 .unwrap()
@@ -102,7 +105,7 @@ impl SchedulerGrpc for SchedulerServer {
         info!("Received get_executors_metadata request");
         let result = self
             .state
-            .get_executors_metadata(self.namespace.as_str())
+            .get_executors_metadata()
             .await
             .map_err(|e| {
                 let msg = format!("Error reading executors metadata: {}", e);
@@ -135,17 +138,16 @@ impl SchedulerGrpc for SchedulerServer {
                 tonic::Status::internal(msg)
             })?;
             self.state
-                .save_executor_metadata(&self.namespace, metadata.clone())
+                .save_executor_metadata(metadata.clone())
                 .await
                 .map_err(|e| {
                     let msg = format!("Could not save executor metadata: {}", 
e);
                     error!("{}", msg);
                     tonic::Status::internal(msg)
                 })?;
-            let task_status_empty = task_status.is_empty();
             for task_status in task_status {
                 self.state
-                    .save_task_status(&self.namespace, &task_status)
+                    .save_task_status(&task_status)
                     .await
                     .map_err(|e| {
                         let msg = format!("Could not save task status: {}", e);
@@ -156,7 +158,7 @@ impl SchedulerGrpc for SchedulerServer {
             let task = if can_accept_task {
                 let plan = self
                     .state
-                    .assign_next_schedulable_task(&self.namespace, 
&metadata.id)
+                    .assign_next_schedulable_task(&metadata.id)
                     .await
                     .map_err(|e| {
                         let msg = format!("Error finding next assignable task: 
{}", e);
@@ -180,12 +182,6 @@ impl SchedulerGrpc for SchedulerServer {
             } else {
                 None
             };
-            // TODO: this should probably happen asynchronously with a watch 
on etc/sled
-            if !task_status_empty {
-                if let Err(e) = 
self.state.synchronize_job_status(&self.namespace).await {
-                    warn!("Could not synchronize jobs and tasks state: {}", e);
-                }
-            }
             lock.unlock().await;
             Ok(Response::new(PollWorkResult { task }))
         } else {
@@ -264,15 +260,11 @@ impl SchedulerGrpc for SchedulerServer {
                 }
             };
             debug!("Received plan for execution: {:?}", plan);
-            let executors = self
-                .state
-                .get_executors_metadata(&self.namespace)
-                .await
-                .map_err(|e| {
-                    let msg = format!("Error reading executors metadata: {}", 
e);
-                    error!("{}", msg);
-                    tonic::Status::internal(msg)
-                })?;
+            let executors = 
self.state.get_executors_metadata().await.map_err(|e| {
+                let msg = format!("Error reading executors metadata: {}", e);
+                error!("{}", msg);
+                tonic::Status::internal(msg)
+            })?;
             debug!("Found executors: {:?}", executors);
 
             let job_id: String = {
@@ -287,7 +279,6 @@ impl SchedulerGrpc for SchedulerServer {
             // Save placeholder job metadata
             self.state
                 .save_job_metadata(
-                    &self.namespace,
                     &job_id,
                     &JobStatus {
                         status: Some(job_status::Status::Queued(QueuedJob {})),
@@ -298,7 +289,6 @@ impl SchedulerGrpc for SchedulerServer {
                     tonic::Status::internal(format!("Could not save job 
metadata: {}", e))
                 })?;
 
-            let namespace = self.namespace.to_owned();
             let state = self.state.clone();
             let job_id_spawn = job_id.clone();
             tokio::spawn(async move {
@@ -311,7 +301,6 @@ impl SchedulerGrpc for SchedulerServer {
                                 warn!("Job {} failed with {}", job_id_spawn, 
error);
                                 state
                                     .save_job_metadata(
-                                        &namespace,
                                         &job_id_spawn,
                                         &JobStatus {
                                             status: 
Some(job_status::Status::Failed(
@@ -358,7 +347,6 @@ impl SchedulerGrpc for SchedulerServer {
                 // create distributed physical plan using Ballista
                 if let Err(e) = state
                     .save_job_metadata(
-                        &namespace,
                         &job_id_spawn,
                         &JobStatus {
                             status: 
Some(job_status::Status::Running(RunningJob {})),
@@ -389,7 +377,6 @@ impl SchedulerGrpc for SchedulerServer {
                 for stage in stages {
                     fail_job!(state
                         .save_stage_plan(
-                            &namespace,
                             &job_id_spawn,
                             stage.stage_id,
                             stage.child.clone()
@@ -410,14 +397,13 @@ impl SchedulerGrpc for SchedulerServer {
                             }),
                             status: None,
                         };
-                        fail_job!(state
-                            .save_task_status(&namespace, &pending_status)
-                            .await
-                            .map_err(|e| {
+                        
fail_job!(state.save_task_status(&pending_status).await.map_err(
+                            |e| {
                                 let msg = format!("Could not save task status: 
{}", e);
                                 error!("{}", msg);
                                 tonic::Status::internal(msg)
-                            }));
+                            }
+                        ));
                     }
                 }
             });
@@ -434,15 +420,11 @@ impl SchedulerGrpc for SchedulerServer {
     ) -> std::result::Result<Response<GetJobStatusResult>, tonic::Status> {
         let job_id = request.into_inner().job_id;
         debug!("Received get_job_status request for job {}", job_id);
-        let job_meta = self
-            .state
-            .get_job_metadata(&self.namespace, &job_id)
-            .await
-            .map_err(|e| {
-                let msg = format!("Error reading job metadata: {}", e);
-                error!("{}", msg);
-                tonic::Status::internal(msg)
-            })?;
+        let job_meta = self.state.get_job_metadata(&job_id).await.map_err(|e| {
+            let msg = format!("Error reading job metadata: {}", e);
+            error!("{}", msg);
+            tonic::Status::internal(msg)
+        })?;
         Ok(Response::new(GetJobStatusResult {
             status: Some(job_meta),
         }))
@@ -468,7 +450,7 @@ mod test {
         let state = Arc::new(StandaloneClient::try_new_temporary()?);
         let namespace = "default";
         let scheduler = SchedulerServer::new(state.clone(), 
namespace.to_owned());
-        let state = SchedulerState::new(state);
+        let state = SchedulerState::new(state, namespace.to_string());
         let exec_meta = ExecutorMetadata {
             id: "abc".to_owned(),
             host: "".to_owned(),
@@ -487,10 +469,7 @@ mod test {
         // no response task since we told the scheduler we didn't want to 
accept one
         assert!(response.task.is_none());
         // executor should be registered
-        assert_eq!(
-            state.get_executors_metadata(namespace).await.unwrap().len(),
-            1
-        );
+        assert_eq!(state.get_executors_metadata().await.unwrap().len(), 1);
 
         let request: Request<PollWorkParams> = Request::new(PollWorkParams {
             metadata: Some(exec_meta.clone()),
@@ -505,10 +484,7 @@ mod test {
         // still no response task since there are no tasks in the scheduelr
         assert!(response.task.is_none());
         // executor should be registered
-        assert_eq!(
-            state.get_executors_metadata(namespace).await.unwrap().len(),
-            1
-        );
+        assert_eq!(state.get_executors_metadata().await.unwrap().len(), 1);
         Ok(())
     }
 }
diff --git a/rust/ballista/rust/scheduler/src/state/etcd.rs 
b/rust/ballista/rust/scheduler/src/state/etcd.rs
index ced2461..807477d 100644
--- a/rust/ballista/rust/scheduler/src/state/etcd.rs
+++ b/rust/ballista/rust/scheduler/src/state/etcd.rs
@@ -17,15 +17,18 @@
 
 //! Etcd config backend.
 
-use std::time::Duration;
+use std::{task::Poll, time::Duration};
 
 use crate::state::ConfigBackendClient;
 use ballista_core::error::{ballista_error, Result};
 
-use etcd_client::{GetOptions, LockResponse, PutOptions};
+use etcd_client::{
+    GetOptions, LockResponse, PutOptions, WatchOptions, WatchStream, Watcher,
+};
+use futures::{Stream, StreamExt};
 use log::warn;
 
-use super::Lock;
+use super::{Lock, Watch, WatchEvent};
 
 /// A [`ConfigBackendClient`] implementation that uses etcd to save cluster 
configuration.
 #[derive(Clone)]
@@ -105,6 +108,87 @@ impl ConfigBackendClient for EtcdClient {
             })?;
         Ok(Box::new(EtcdLockGuard { etcd, lock }))
     }
+
+    async fn watch(&self, prefix: String) -> Result<Box<dyn Watch>> {
+        let mut etcd = self.etcd.clone();
+        let options = WatchOptions::new().with_prefix();
+        let (watcher, stream) = etcd.watch(prefix, 
Some(options)).await.map_err(|e| {
+            warn!("etcd watch failed: {}", e);
+            ballista_error("etcd watch failed")
+        })?;
+        Ok(Box::new(EtcdWatch {
+            watcher,
+            stream,
+            buffered_events: Vec::new(),
+        }))
+    }
+}
+
+struct EtcdWatch {
+    watcher: Watcher,
+    stream: WatchStream,
+    buffered_events: Vec<WatchEvent>,
+}
+
+#[tonic::async_trait]
+impl Watch for EtcdWatch {
+    async fn cancel(&mut self) -> Result<()> {
+        self.watcher.cancel().await.map_err(|e| {
+            warn!("etcd watch cancel failed: {}", e);
+            ballista_error("etcd watch cancel failed")
+        })
+    }
+}
+
+impl Stream for EtcdWatch {
+    type Item = WatchEvent;
+
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let self_mut = self.get_mut();
+        if let Some(event) = self_mut.buffered_events.pop() {
+            Poll::Ready(Some(event))
+        } else {
+            loop {
+                match self_mut.stream.poll_next_unpin(cx) {
+                    Poll::Ready(Some(Err(e))) => {
+                        warn!("Error when watching etcd prefix: {}", e);
+                        continue;
+                    }
+                    Poll::Ready(Some(Ok(v))) => {
+                        
self_mut.buffered_events.extend(v.events().iter().map(|ev| {
+                            match ev.event_type() {
+                                etcd_client::EventType::Put => {
+                                    let kv = ev.kv().unwrap();
+                                    WatchEvent::Put(
+                                        kv.key_str().unwrap().to_string(),
+                                        kv.value().to_owned(),
+                                    )
+                                }
+                                etcd_client::EventType::Delete => {
+                                    let kv = ev.kv().unwrap();
+                                    
WatchEvent::Delete(kv.key_str().unwrap().to_string())
+                                }
+                            }
+                        }));
+                        if let Some(event) = self_mut.buffered_events.pop() {
+                            return Poll::Ready(Some(event));
+                        } else {
+                            continue;
+                        }
+                    }
+                    Poll::Ready(None) => return Poll::Ready(None),
+                    Poll::Pending => return Poll::Pending,
+                }
+            }
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.stream.size_hint()
+    }
 }
 
 struct EtcdLockGuard {
diff --git a/rust/ballista/rust/scheduler/src/state/mod.rs 
b/rust/ballista/rust/scheduler/src/state/mod.rs
index 794e58f..a15efd6 100644
--- a/rust/ballista/rust/scheduler/src/state/mod.rs
+++ b/rust/ballista/rust/scheduler/src/state/mod.rs
@@ -20,7 +20,8 @@ use std::{
 };
 
 use datafusion::physical_plan::ExecutionPlan;
-use log::{debug, info};
+use futures::{Stream, StreamExt};
+use log::{debug, error, info};
 use prost::Message;
 use tokio::sync::OwnedMutexGuard;
 
@@ -69,27 +70,46 @@ pub trait ConfigBackendClient: Send + Sync {
     ) -> Result<()>;
 
     async fn lock(&self) -> Result<Box<dyn Lock>>;
+
+    /// Watch all events that happen on a specific prefix.
+    async fn watch(&self, prefix: String) -> Result<Box<dyn Watch>>;
+}
+
+/// A Watch is a cancelable stream of put or delete events in the 
[ConfigBackendClient]
+#[tonic::async_trait]
+pub trait Watch: Stream<Item = WatchEvent> + Send + Unpin {
+    async fn cancel(&mut self) -> Result<()>;
+}
+
+#[derive(Debug, PartialEq)]
+pub enum WatchEvent {
+    /// Contains the inserted or updated key and the new value
+    Put(String, Vec<u8>),
+
+    /// Contains the deleted key
+    Delete(String),
 }
 
 #[derive(Clone)]
 pub(super) struct SchedulerState {
     config_client: Arc<dyn ConfigBackendClient>,
+    namespace: String,
 }
 
 impl SchedulerState {
-    pub fn new(config_client: Arc<dyn ConfigBackendClient>) -> Self {
-        Self { config_client }
+    pub fn new(config_client: Arc<dyn ConfigBackendClient>, namespace: String) 
-> Self {
+        Self {
+            config_client,
+            namespace,
+        }
     }
 
-    pub async fn get_executors_metadata(
-        &self,
-        namespace: &str,
-    ) -> Result<Vec<ExecutorMeta>> {
+    pub async fn get_executors_metadata(&self) -> Result<Vec<ExecutorMeta>> {
         let mut result = vec![];
 
         let entries = self
             .config_client
-            .get_from_prefix(&get_executors_prefix(namespace))
+            .get_from_prefix(&get_executors_prefix(&self.namespace))
             .await?;
         for (_key, entry) in entries {
             let meta: ExecutorMetadata = decode_protobuf(&entry)?;
@@ -98,12 +118,8 @@ impl SchedulerState {
         Ok(result)
     }
 
-    pub async fn save_executor_metadata(
-        &self,
-        namespace: &str,
-        meta: ExecutorMeta,
-    ) -> Result<()> {
-        let key = get_executor_key(namespace, &meta.id);
+    pub async fn save_executor_metadata(&self, meta: ExecutorMeta) -> 
Result<()> {
+        let key = get_executor_key(&self.namespace, &meta.id);
         let meta: ExecutorMetadata = meta.into();
         let value: Vec<u8> = encode_protobuf(&meta)?;
         self.config_client.put(key, value, Some(LEASE_TIME)).await
@@ -111,22 +127,17 @@ impl SchedulerState {
 
     pub async fn save_job_metadata(
         &self,
-        namespace: &str,
         job_id: &str,
         status: &JobStatus,
     ) -> Result<()> {
         debug!("Saving job metadata: {:?}", status);
-        let key = get_job_key(namespace, job_id);
+        let key = get_job_key(&self.namespace, job_id);
         let value = encode_protobuf(status)?;
         self.config_client.put(key, value, None).await
     }
 
-    pub async fn get_job_metadata(
-        &self,
-        namespace: &str,
-        job_id: &str,
-    ) -> Result<JobStatus> {
-        let key = get_job_key(namespace, job_id);
+    pub async fn get_job_metadata(&self, job_id: &str) -> Result<JobStatus> {
+        let key = get_job_key(&self.namespace, job_id);
         let value = &self.config_client.get(&key).await?;
         if value.is_empty() {
             return Err(BallistaError::General(format!(
@@ -138,14 +149,10 @@ impl SchedulerState {
         Ok(value)
     }
 
-    pub async fn save_task_status(
-        &self,
-        namespace: &str,
-        status: &TaskStatus,
-    ) -> Result<()> {
+    pub async fn save_task_status(&self, status: &TaskStatus) -> Result<()> {
         let partition_id = status.partition_id.as_ref().unwrap();
         let key = get_task_status_key(
-            namespace,
+            &self.namespace,
             &partition_id.job_id,
             partition_id.stage_id as usize,
             partition_id.partition_id as usize,
@@ -156,12 +163,11 @@ impl SchedulerState {
 
     pub async fn _get_task_status(
         &self,
-        namespace: &str,
         job_id: &str,
         stage_id: usize,
         partition_id: usize,
     ) -> Result<TaskStatus> {
-        let key = get_task_status_key(namespace, job_id, stage_id, 
partition_id);
+        let key = get_task_status_key(&self.namespace, job_id, stage_id, 
partition_id);
         let value = &self.config_client.clone().get(&key).await?;
         if value.is_empty() {
             return Err(BallistaError::General(format!(
@@ -176,12 +182,11 @@ impl SchedulerState {
     // "Unnecessary" lifetime syntax due to 
https://github.com/rust-lang/rust/issues/63033
     pub async fn save_stage_plan<'a>(
         &'a self,
-        namespace: &'a str,
         job_id: &'a str,
         stage_id: usize,
         plan: Arc<dyn ExecutionPlan>,
     ) -> Result<()> {
-        let key = get_stage_plan_key(namespace, job_id, stage_id);
+        let key = get_stage_plan_key(&self.namespace, job_id, stage_id);
         let value = {
             let proto: PhysicalPlanNode = plan.try_into()?;
             encode_protobuf(&proto)?
@@ -191,11 +196,10 @@ impl SchedulerState {
 
     pub async fn get_stage_plan(
         &self,
-        namespace: &str,
         job_id: &str,
         stage_id: usize,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        let key = get_stage_plan_key(namespace, job_id, stage_id);
+        let key = get_stage_plan_key(&self.namespace, job_id, stage_id);
         let value = &self.config_client.get(&key).await?;
         if value.is_empty() {
             return Err(BallistaError::General(format!(
@@ -209,26 +213,21 @@ impl SchedulerState {
 
     pub async fn assign_next_schedulable_task(
         &self,
-        namespace: &str,
         executor_id: &str,
     ) -> Result<Option<(TaskStatus, Arc<dyn ExecutionPlan>)>> {
         let kvs: HashMap<String, Vec<u8>> = self
             .config_client
-            .get_from_prefix(&get_task_prefix(namespace))
+            .get_from_prefix(&get_task_prefix(&self.namespace))
             .await?
             .into_iter()
             .collect();
-        let executors = self.get_executors_metadata(namespace).await?;
+        let executors = self.get_executors_metadata().await?;
         'tasks: for (_key, value) in kvs.iter() {
             let mut status: TaskStatus = decode_protobuf(&value)?;
             if status.status.is_none() {
                 let partition = status.partition_id.as_ref().unwrap();
                 let plan = self
-                    .get_stage_plan(
-                        namespace,
-                        &partition.job_id,
-                        partition.stage_id as usize,
-                    )
+                    .get_stage_plan(&partition.job_id, partition.stage_id as 
usize)
                     .await?;
 
                 // Let's try to resolve any unresolved shuffles we find
@@ -242,7 +241,7 @@ impl SchedulerState {
                         for partition_id in 
0..unresolved_shuffle.partition_count {
                             let referenced_task = kvs
                                 .get(&get_task_status_key(
-                                    namespace,
+                                    &self.namespace,
                                     &partition.job_id,
                                     stage_id,
                                     partition_id,
@@ -286,7 +285,7 @@ impl SchedulerState {
                 status.status = Some(task_status::Status::Running(RunningTask {
                     executor_id: executor_id.to_owned(),
                 }));
-                self.save_task_status(namespace, &status).await?;
+                self.save_task_status(&status).await?;
                 return Ok(Some((status, plan)));
             }
         }
@@ -298,34 +297,58 @@ impl SchedulerState {
         self.config_client.lock().await
     }
 
-    pub async fn synchronize_job_status(&self, namespace: &str) -> Result<()> {
-        let kvs = self
+    /// This function starts a watch over the task keys. Whenever a task 
changes, it re-evaluates
+    /// the status for the parent job and updates it accordingly.
+    ///
+    /// The future returned by this function never returns (unless an error 
happens), so it is wise
+    /// to [tokio::spawn] calls to this method.
+    pub async fn synchronize_job_status_loop(&self) -> Result<()> {
+        let watch = self
+            .config_client
+            .watch(get_task_prefix(&self.namespace))
+            .await?;
+        watch.for_each(|event: WatchEvent| async {
+            let key = match event {
+                WatchEvent::Put(key, _value) => key,
+                WatchEvent::Delete(key) => key
+            };
+            let job_id = extract_job_id_from_task_key(&key).unwrap();
+            match self.lock().await {
+                Ok(mut lock) => {
+                    if let Err(e) = self.synchronize_job_status(job_id).await {
+                        error!("Could not update job status for {}. This job 
might be stuck forever. Error: {}", job_id, e);
+                    }
+                    lock.unlock().await;
+                },
+                Err(e) => error!("Could not lock config backend. Job {} will 
have an unsynchronized status and might be stuck forever. Error: {}", job_id, e)
+            }
+        }).await;
+
+        Ok(())
+    }
+
+    async fn synchronize_job_status(&self, job_id: &str) -> Result<()> {
+        let value = self
             .config_client
-            .get_from_prefix(&get_job_prefix(namespace))
+            .get(&get_job_key(&self.namespace, job_id))
             .await?;
         let executors: HashMap<String, ExecutorMeta> = self
-            .get_executors_metadata(namespace)
+            .get_executors_metadata()
             .await?
             .into_iter()
             .map(|meta| (meta.id.to_string(), meta))
             .collect();
-        for (key, value) in kvs {
-            let job_id = extract_job_id_from_key(&key)?;
-            let status: JobStatus = decode_protobuf(&value)?;
-            let new_status = self
-                .get_job_status_from_tasks(namespace, job_id, &executors)
-                .await?;
-            if let Some(new_status) = new_status {
-                if status != new_status {
-                    info!(
-                        "Changing status for job {} to {:?}",
-                        job_id, new_status.status
-                    );
-                    debug!("Old status: {:?}", status);
-                    debug!("New status: {:?}", new_status);
-                    self.save_job_metadata(namespace, job_id, &new_status)
-                        .await?;
-                }
+        let status: JobStatus = decode_protobuf(&value)?;
+        let new_status = self.get_job_status_from_tasks(job_id, 
&executors).await?;
+        if let Some(new_status) = new_status {
+            if status != new_status {
+                info!(
+                    "Changing status for job {} to {:?}",
+                    job_id, new_status.status
+                );
+                debug!("Old status: {:?}", status);
+                debug!("New status: {:?}", new_status);
+                self.save_job_metadata(job_id, &new_status).await?;
             }
         }
         Ok(())
@@ -333,13 +356,12 @@ impl SchedulerState {
 
     async fn get_job_status_from_tasks(
         &self,
-        namespace: &str,
         job_id: &str,
         executors: &HashMap<String, ExecutorMeta>,
     ) -> Result<Option<JobStatus>> {
         let statuses = self
             .config_client
-            .get_from_prefix(&get_task_prefix_for_job(namespace, job_id))
+            .get_from_prefix(&get_task_prefix_for_job(&self.namespace, job_id))
             .await?
             .into_iter()
             .map(|(_k, v)| decode_protobuf::<TaskStatus>(&v))
@@ -446,12 +468,6 @@ fn get_job_prefix(namespace: &str) -> String {
     format!("/ballista/{}/jobs", namespace)
 }
 
-fn extract_job_id_from_key(job_key: &str) -> Result<&str> {
-    job_key.split('/').nth(4).ok_or_else(|| {
-        BallistaError::Internal(format!("Unexpected job key: {}", job_key))
-    })
-}
-
 fn get_job_key(namespace: &str, id: &str) -> String {
     format!("{}/{}", get_job_prefix(namespace), id)
 }
@@ -478,6 +494,12 @@ fn get_task_status_key(
     )
 }
 
+fn extract_job_id_from_task_key(job_key: &str) -> Result<&str> {
+    job_key.split('/').nth(4).ok_or_else(|| {
+        BallistaError::Internal(format!("Unexpected task key: {}", job_key))
+    })
+}
+
 fn get_stage_plan_key(namespace: &str, job_id: &str, stage_id: usize) -> 
String {
     format!("/ballista/{}/stages/{}/{}", namespace, job_id, stage_id,)
 }
@@ -514,44 +536,39 @@ mod test {
     };
     use ballista_core::{error::BallistaError, serde::scheduler::ExecutorMeta};
 
-    use super::{SchedulerState, StandaloneClient};
+    use super::{
+        extract_job_id_from_task_key, get_task_status_key, SchedulerState,
+        StandaloneClient,
+    };
 
     #[tokio::test]
     async fn executor_metadata() -> Result<(), BallistaError> {
-        let state = 
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
+        let state = SchedulerState::new(
+            Arc::new(StandaloneClient::try_new_temporary()?),
+            "test".to_string(),
+        );
         let meta = ExecutorMeta {
             id: "123".to_owned(),
             host: "localhost".to_owned(),
             port: 123,
         };
-        state.save_executor_metadata("test", meta.clone()).await?;
-        let result = state.get_executors_metadata("test").await?;
+        state.save_executor_metadata(meta.clone()).await?;
+        let result = state.get_executors_metadata().await?;
         assert_eq!(vec![meta], result);
         Ok(())
     }
 
     #[tokio::test]
-    async fn executor_metadata_empty() -> Result<(), BallistaError> {
-        let state = 
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
-        let meta = ExecutorMeta {
-            id: "123".to_owned(),
-            host: "localhost".to_owned(),
-            port: 123,
-        };
-        state.save_executor_metadata("test", meta.clone()).await?;
-        let result = state.get_executors_metadata("test2").await?;
-        assert!(result.is_empty());
-        Ok(())
-    }
-
-    #[tokio::test]
     async fn job_metadata() -> Result<(), BallistaError> {
-        let state = 
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
+        let state = SchedulerState::new(
+            Arc::new(StandaloneClient::try_new_temporary()?),
+            "test".to_string(),
+        );
         let meta = JobStatus {
             status: Some(job_status::Status::Queued(QueuedJob {})),
         };
-        state.save_job_metadata("test", "job", &meta).await?;
-        let result = state.get_job_metadata("test", "job").await?;
+        state.save_job_metadata("job", &meta).await?;
+        let result = state.get_job_metadata("job").await?;
         assert!(result.status.is_some());
         match result.status.unwrap() {
             job_status::Status::Queued(_) => (),
@@ -562,19 +579,25 @@ mod test {
 
     #[tokio::test]
     async fn job_metadata_non_existant() -> Result<(), BallistaError> {
-        let state = 
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
+        let state = SchedulerState::new(
+            Arc::new(StandaloneClient::try_new_temporary()?),
+            "test".to_string(),
+        );
         let meta = JobStatus {
             status: Some(job_status::Status::Queued(QueuedJob {})),
         };
-        state.save_job_metadata("test", "job", &meta).await?;
-        let result = state.get_job_metadata("test2", "job2").await;
+        state.save_job_metadata("job", &meta).await?;
+        let result = state.get_job_metadata("job2").await;
         assert!(result.is_err());
         Ok(())
     }
 
     #[tokio::test]
     async fn task_status() -> Result<(), BallistaError> {
-        let state = 
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
+        let state = SchedulerState::new(
+            Arc::new(StandaloneClient::try_new_temporary()?),
+            "test".to_string(),
+        );
         let meta = TaskStatus {
             status: Some(task_status::Status::Failed(FailedTask {
                 error: "error".to_owned(),
@@ -585,8 +608,8 @@ mod test {
                 partition_id: 2,
             }),
         };
-        state.save_task_status("test", &meta).await?;
-        let result = state._get_task_status("test", "job", 1, 2).await?;
+        state.save_task_status(&meta).await?;
+        let result = state._get_task_status("job", 1, 2).await?;
         assert!(result.status.is_some());
         match result.status.unwrap() {
             task_status::Status::Failed(_) => (),
@@ -597,7 +620,10 @@ mod test {
 
     #[tokio::test]
     async fn task_status_non_existant() -> Result<(), BallistaError> {
-        let state = SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
+        let state = SchedulerState::new(
+            Arc::new(StandaloneClient::try_new_temporary()?),
+            "test".to_string(),
+        );
         let meta = TaskStatus {
             status: Some(task_status::Status::Failed(FailedTask {
                 error: "error".to_owned(),
@@ -608,40 +634,40 @@ mod test {
                 partition_id: 2,
             }),
         };
-        state.save_task_status("test", &meta).await?;
-        let result = state._get_task_status("test", "job", 25, 2).await;
+        state.save_task_status(&meta).await?;
+        let result = state._get_task_status("job", 25, 2).await;
         assert!(result.is_err());
         Ok(())
     }
 
     #[tokio::test]
     async fn task_synchronize_job_status_queued() -> Result<(), BallistaError> {
-        let state = SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
-        let namespace = "default";
+        let state = SchedulerState::new(
+            Arc::new(StandaloneClient::try_new_temporary()?),
+            "test".to_string(),
+        );
         let job_id = "job";
         let job_status = JobStatus {
             status: Some(job_status::Status::Queued(QueuedJob {})),
         };
-        state
-            .save_job_metadata(namespace, job_id, &job_status)
-            .await?;
-        state.synchronize_job_status(namespace).await?;
-        let result = state.get_job_metadata(namespace, job_id).await?;
+        state.save_job_metadata(job_id, &job_status).await?;
+        state.synchronize_job_status(job_id).await?;
+        let result = state.get_job_metadata(job_id).await?;
         assert_eq!(result, job_status);
         Ok(())
     }
 
     #[tokio::test]
     async fn task_synchronize_job_status_running() -> Result<(), BallistaError> {
-        let state = SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
-        let namespace = "default";
+        let state = SchedulerState::new(
+            Arc::new(StandaloneClient::try_new_temporary()?),
+            "test".to_string(),
+        );
         let job_id = "job";
         let job_status = JobStatus {
             status: Some(job_status::Status::Running(RunningJob {})),
         };
-        state
-            .save_job_metadata(namespace, job_id, &job_status)
-            .await?;
+        state.save_job_metadata(job_id, &job_status).await?;
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
@@ -652,7 +678,7 @@ mod test {
                 partition_id: 0,
             }),
         };
-        state.save_task_status(namespace, &meta).await?;
+        state.save_task_status(&meta).await?;
         let meta = TaskStatus {
             status: Some(task_status::Status::Running(RunningTask {
                 executor_id: "".to_owned(),
@@ -663,24 +689,24 @@ mod test {
                 partition_id: 1,
             }),
         };
-        state.save_task_status(namespace, &meta).await?;
-        state.synchronize_job_status(namespace).await?;
-        let result = state.get_job_metadata(namespace, job_id).await?;
+        state.save_task_status(&meta).await?;
+        state.synchronize_job_status(job_id).await?;
+        let result = state.get_job_metadata(job_id).await?;
         assert_eq!(result, job_status);
         Ok(())
     }
 
     #[tokio::test]
     async fn task_synchronize_job_status_running2() -> Result<(), BallistaError> {
-        let state = SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
-        let namespace = "default";
+        let state = SchedulerState::new(
+            Arc::new(StandaloneClient::try_new_temporary()?),
+            "test".to_string(),
+        );
         let job_id = "job";
         let job_status = JobStatus {
             status: Some(job_status::Status::Running(RunningJob {})),
         };
-        state
-            .save_job_metadata(namespace, job_id, &job_status)
-            .await?;
+        state.save_job_metadata(job_id, &job_status).await?;
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
@@ -691,7 +717,7 @@ mod test {
                 partition_id: 0,
             }),
         };
-        state.save_task_status(namespace, &meta).await?;
+        state.save_task_status(&meta).await?;
         let meta = TaskStatus {
             status: None,
             partition_id: Some(PartitionId {
@@ -700,24 +726,24 @@ mod test {
                 partition_id: 1,
             }),
         };
-        state.save_task_status(namespace, &meta).await?;
-        state.synchronize_job_status(namespace).await?;
-        let result = state.get_job_metadata(namespace, job_id).await?;
+        state.save_task_status(&meta).await?;
+        state.synchronize_job_status(job_id).await?;
+        let result = state.get_job_metadata(job_id).await?;
         assert_eq!(result, job_status);
         Ok(())
     }
 
     #[tokio::test]
     async fn task_synchronize_job_status_completed() -> Result<(), BallistaError> {
-        let state = SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
-        let namespace = "default";
+        let state = SchedulerState::new(
+            Arc::new(StandaloneClient::try_new_temporary()?),
+            "test".to_string(),
+        );
         let job_id = "job";
         let job_status = JobStatus {
             status: Some(job_status::Status::Running(RunningJob {})),
         };
-        state
-            .save_job_metadata(namespace, job_id, &job_status)
-            .await?;
+        state.save_job_metadata(job_id, &job_status).await?;
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
@@ -728,7 +754,7 @@ mod test {
                 partition_id: 0,
             }),
         };
-        state.save_task_status(namespace, &meta).await?;
+        state.save_task_status(&meta).await?;
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
@@ -739,9 +765,9 @@ mod test {
                 partition_id: 1,
             }),
         };
-        state.save_task_status(namespace, &meta).await?;
-        state.synchronize_job_status(namespace).await?;
-        let result = state.get_job_metadata(namespace, job_id).await?;
+        state.save_task_status(&meta).await?;
+        state.synchronize_job_status(job_id).await?;
+        let result = state.get_job_metadata(job_id).await?;
         match result.status.unwrap() {
             job_status::Status::Completed(_) => (),
             status => panic!("Received status: {:?}", status),
@@ -751,15 +777,15 @@ mod test {
 
     #[tokio::test]
     async fn task_synchronize_job_status_completed2() -> Result<(), BallistaError> {
-        let state = SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
-        let namespace = "default";
+        let state = SchedulerState::new(
+            Arc::new(StandaloneClient::try_new_temporary()?),
+            "test".to_string(),
+        );
         let job_id = "job";
         let job_status = JobStatus {
             status: Some(job_status::Status::Queued(QueuedJob {})),
         };
-        state
-            .save_job_metadata(namespace, job_id, &job_status)
-            .await?;
+        state.save_job_metadata(job_id, &job_status).await?;
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
@@ -770,7 +796,7 @@ mod test {
                 partition_id: 0,
             }),
         };
-        state.save_task_status(namespace, &meta).await?;
+        state.save_task_status(&meta).await?;
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
@@ -781,9 +807,9 @@ mod test {
                 partition_id: 1,
             }),
         };
-        state.save_task_status(namespace, &meta).await?;
-        state.synchronize_job_status(namespace).await?;
-        let result = state.get_job_metadata(namespace, job_id).await?;
+        state.save_task_status(&meta).await?;
+        state.synchronize_job_status(job_id).await?;
+        let result = state.get_job_metadata(job_id).await?;
         match result.status.unwrap() {
             job_status::Status::Completed(_) => (),
             status => panic!("Received status: {:?}", status),
@@ -793,15 +819,15 @@ mod test {
 
     #[tokio::test]
     async fn task_synchronize_job_status_failed() -> Result<(), BallistaError> {
-        let state = SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
-        let namespace = "default";
+        let state = SchedulerState::new(
+            Arc::new(StandaloneClient::try_new_temporary()?),
+            "test".to_string(),
+        );
         let job_id = "job";
         let job_status = JobStatus {
             status: Some(job_status::Status::Running(RunningJob {})),
         };
-        state
-            .save_job_metadata(namespace, job_id, &job_status)
-            .await?;
+        state.save_job_metadata(job_id, &job_status).await?;
         let meta = TaskStatus {
             status: Some(task_status::Status::Completed(CompletedTask {
                 executor_id: "".to_owned(),
@@ -812,7 +838,7 @@ mod test {
                 partition_id: 0,
             }),
         };
-        state.save_task_status(namespace, &meta).await?;
+        state.save_task_status(&meta).await?;
         let meta = TaskStatus {
             status: Some(task_status::Status::Failed(FailedTask {
                 error: "".to_owned(),
@@ -823,7 +849,7 @@ mod test {
                 partition_id: 1,
             }),
         };
-        state.save_task_status(namespace, &meta).await?;
+        state.save_task_status(&meta).await?;
         let meta = TaskStatus {
             status: None,
             partition_id: Some(PartitionId {
@@ -832,13 +858,23 @@ mod test {
                 partition_id: 2,
             }),
         };
-        state.save_task_status(namespace, &meta).await?;
-        state.synchronize_job_status(namespace).await?;
-        let result = state.get_job_metadata(namespace, job_id).await?;
+        state.save_task_status(&meta).await?;
+        state.synchronize_job_status(job_id).await?;
+        let result = state.get_job_metadata(job_id).await?;
         match result.status.unwrap() {
             job_status::Status::Failed(_) => (),
             status => panic!("Received status: {:?}", status),
         }
         Ok(())
     }
+
+    #[test]
+    fn task_extract_job_id_from_task_key() {
+        let job_id = "foo";
+        assert_eq!(
+            extract_job_id_from_task_key(&get_task_status_key("namespace", job_id, 0, 1))
+                .unwrap(),
+            job_id
+        );
+    }
 }
diff --git a/rust/ballista/rust/scheduler/src/state/standalone.rs b/rust/ballista/rust/scheduler/src/state/standalone.rs
index e07d45e..69805c0 100644
--- a/rust/ballista/rust/scheduler/src/state/standalone.rs
+++ b/rust/ballista/rust/scheduler/src/state/standalone.rs
@@ -15,15 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::{sync::Arc, time::Duration};
+use std::{sync::Arc, task::Poll, time::Duration};
 
 use crate::state::ConfigBackendClient;
 use ballista_core::error::{ballista_error, BallistaError, Result};
 
+use futures::{FutureExt, Stream};
 use log::warn;
+use sled::{Event, Subscriber};
 use tokio::sync::Mutex;
 
-use super::Lock;
+use super::{Lock, Watch, WatchEvent};
 
 /// A [`ConfigBackendClient`] implementation that uses file-based storage to save cluster configuration.
 #[derive(Clone)]
@@ -106,13 +108,57 @@ impl ConfigBackendClient for StandaloneClient {
     async fn lock(&self) -> Result<Box<dyn Lock>> {
         Ok(Box::new(self.lock.clone().lock_owned().await))
     }
+
+    async fn watch(&self, prefix: String) -> Result<Box<dyn Watch>> {
+        Ok(Box::new(SledWatch {
+            subscriber: self.db.watch_prefix(prefix),
+        }))
+    }
+}
+
+struct SledWatch {
+    subscriber: Subscriber,
+}
+
+#[tonic::async_trait]
+impl Watch for SledWatch {
+    async fn cancel(&mut self) -> Result<()> {
+        Ok(())
+    }
+}
+
+impl Stream for SledWatch {
+    type Item = WatchEvent;
+
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        match self.get_mut().subscriber.poll_unpin(cx) {
+            Poll::Pending => Poll::Pending,
+            Poll::Ready(None) => Poll::Ready(None),
+            Poll::Ready(Some(Event::Insert { key, value })) => {
+                let key = std::str::from_utf8(&key).unwrap().to_owned();
+                Poll::Ready(Some(WatchEvent::Put(key, value.to_vec())))
+            }
+            Poll::Ready(Some(Event::Remove { key })) => {
+                let key = std::str::from_utf8(&key).unwrap().to_owned();
+                Poll::Ready(Some(WatchEvent::Delete(key)))
+            }
+        }
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        self.subscriber.size_hint()
+    }
 }
 
 #[cfg(test)]
 mod tests {
-    use crate::state::ConfigBackendClient;
+    use crate::state::{ConfigBackendClient, Watch, WatchEvent};
 
     use super::StandaloneClient;
+    use futures::StreamExt;
     use std::result::Result;
 
     fn create_instance() -> Result<StandaloneClient, Box<dyn std::error::Error>> {
@@ -158,4 +204,25 @@ mod tests {
         );
         Ok(())
     }
+
+    #[tokio::test]
+    async fn read_watch() -> Result<(), Box<dyn std::error::Error>> {
+        let client = create_instance()?;
+        let key = "key";
+        let value = "value".as_bytes();
+        let mut watch: Box<dyn Watch> = client.watch(key.to_owned()).await?;
+        client.put(key.to_owned(), value.to_vec(), None).await?;
+        assert_eq!(
+            watch.next().await,
+            Some(WatchEvent::Put(key.to_owned(), value.to_owned()))
+        );
+        let value2 = "value2".as_bytes();
+        client.put(key.to_owned(), value2.to_vec(), None).await?;
+        assert_eq!(
+            watch.next().await,
+            Some(WatchEvent::Put(key.to_owned(), value2.to_owned()))
+        );
+        watch.cancel().await?;
+        Ok(())
+    }
 }

Reply via email to