This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 6c1f6cc ARROW-12436: [Rust][Ballista] Add watch capabilities to
config backend trait
6c1f6cc is described below
commit 6c1f6cce4b3e27c4bfca62aa8afb2345baf2fcfb
Author: Ximo Guanter <[email protected]>
AuthorDate: Sun Apr 18 07:22:05 2021 -0600
ARROW-12436: [Rust][Ballista] Add watch capabilities to config backend trait
A small next step towards enabling HA in the scheduler. UT + ITs pass.
cc @andygrove
Closes #10085 from edrevo/state-watch
Authored-by: Ximo Guanter <[email protected]>
Signed-off-by: Andy Grove <[email protected]>
---
rust/ballista/rust/scheduler/src/lib.rs | 80 ++---
rust/ballista/rust/scheduler/src/state/etcd.rs | 90 +++++-
rust/ballista/rust/scheduler/src/state/mod.rs | 348 ++++++++++++---------
.../rust/scheduler/src/state/standalone.rs | 73 ++++-
4 files changed, 377 insertions(+), 214 deletions(-)
diff --git a/rust/ballista/rust/scheduler/src/lib.rs
b/rust/ballista/rust/scheduler/src/lib.rs
index de49bc0..a675153 100644
--- a/rust/ballista/rust/scheduler/src/lib.rs
+++ b/rust/ballista/rust/scheduler/src/lib.rs
@@ -71,8 +71,7 @@ use std::time::{Instant, SystemTime, UNIX_EPOCH};
#[derive(Clone)]
pub struct SchedulerServer {
- state: SchedulerState,
- namespace: String,
+ state: Arc<SchedulerState>,
start_time: u128,
version: String,
}
@@ -80,10 +79,14 @@ pub struct SchedulerServer {
impl SchedulerServer {
pub fn new(config: Arc<dyn ConfigBackendClient>, namespace: String) ->
Self {
const VERSION: Option<&'static str> = option_env!("CARGO_PKG_VERSION");
+ let state = Arc::new(SchedulerState::new(config, namespace));
+ let state_clone = state.clone();
+
+ // TODO: we should elect a leader in the scheduler cluster and run
this only in the leader
+ tokio::spawn(async move {
state_clone.synchronize_job_status_loop().await });
Self {
- state: SchedulerState::new(config),
- namespace,
+ state,
start_time: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
@@ -102,7 +105,7 @@ impl SchedulerGrpc for SchedulerServer {
info!("Received get_executors_metadata request");
let result = self
.state
- .get_executors_metadata(self.namespace.as_str())
+ .get_executors_metadata()
.await
.map_err(|e| {
let msg = format!("Error reading executors metadata: {}", e);
@@ -135,17 +138,16 @@ impl SchedulerGrpc for SchedulerServer {
tonic::Status::internal(msg)
})?;
self.state
- .save_executor_metadata(&self.namespace, metadata.clone())
+ .save_executor_metadata(metadata.clone())
.await
.map_err(|e| {
let msg = format!("Could not save executor metadata: {}",
e);
error!("{}", msg);
tonic::Status::internal(msg)
})?;
- let task_status_empty = task_status.is_empty();
for task_status in task_status {
self.state
- .save_task_status(&self.namespace, &task_status)
+ .save_task_status(&task_status)
.await
.map_err(|e| {
let msg = format!("Could not save task status: {}", e);
@@ -156,7 +158,7 @@ impl SchedulerGrpc for SchedulerServer {
let task = if can_accept_task {
let plan = self
.state
- .assign_next_schedulable_task(&self.namespace,
&metadata.id)
+ .assign_next_schedulable_task(&metadata.id)
.await
.map_err(|e| {
let msg = format!("Error finding next assignable task:
{}", e);
@@ -180,12 +182,6 @@ impl SchedulerGrpc for SchedulerServer {
} else {
None
};
- // TODO: this should probably happen asynchronously with a watch
on etc/sled
- if !task_status_empty {
- if let Err(e) =
self.state.synchronize_job_status(&self.namespace).await {
- warn!("Could not synchronize jobs and tasks state: {}", e);
- }
- }
lock.unlock().await;
Ok(Response::new(PollWorkResult { task }))
} else {
@@ -264,15 +260,11 @@ impl SchedulerGrpc for SchedulerServer {
}
};
debug!("Received plan for execution: {:?}", plan);
- let executors = self
- .state
- .get_executors_metadata(&self.namespace)
- .await
- .map_err(|e| {
- let msg = format!("Error reading executors metadata: {}",
e);
- error!("{}", msg);
- tonic::Status::internal(msg)
- })?;
+ let executors =
self.state.get_executors_metadata().await.map_err(|e| {
+ let msg = format!("Error reading executors metadata: {}", e);
+ error!("{}", msg);
+ tonic::Status::internal(msg)
+ })?;
debug!("Found executors: {:?}", executors);
let job_id: String = {
@@ -287,7 +279,6 @@ impl SchedulerGrpc for SchedulerServer {
// Save placeholder job metadata
self.state
.save_job_metadata(
- &self.namespace,
&job_id,
&JobStatus {
status: Some(job_status::Status::Queued(QueuedJob {})),
@@ -298,7 +289,6 @@ impl SchedulerGrpc for SchedulerServer {
tonic::Status::internal(format!("Could not save job
metadata: {}", e))
})?;
- let namespace = self.namespace.to_owned();
let state = self.state.clone();
let job_id_spawn = job_id.clone();
tokio::spawn(async move {
@@ -311,7 +301,6 @@ impl SchedulerGrpc for SchedulerServer {
warn!("Job {} failed with {}", job_id_spawn,
error);
state
.save_job_metadata(
- &namespace,
&job_id_spawn,
&JobStatus {
status:
Some(job_status::Status::Failed(
@@ -358,7 +347,6 @@ impl SchedulerGrpc for SchedulerServer {
// create distributed physical plan using Ballista
if let Err(e) = state
.save_job_metadata(
- &namespace,
&job_id_spawn,
&JobStatus {
status:
Some(job_status::Status::Running(RunningJob {})),
@@ -389,7 +377,6 @@ impl SchedulerGrpc for SchedulerServer {
for stage in stages {
fail_job!(state
.save_stage_plan(
- &namespace,
&job_id_spawn,
stage.stage_id,
stage.child.clone()
@@ -410,14 +397,13 @@ impl SchedulerGrpc for SchedulerServer {
}),
status: None,
};
- fail_job!(state
- .save_task_status(&namespace, &pending_status)
- .await
- .map_err(|e| {
+
fail_job!(state.save_task_status(&pending_status).await.map_err(
+ |e| {
let msg = format!("Could not save task status:
{}", e);
error!("{}", msg);
tonic::Status::internal(msg)
- }));
+ }
+ ));
}
}
});
@@ -434,15 +420,11 @@ impl SchedulerGrpc for SchedulerServer {
) -> std::result::Result<Response<GetJobStatusResult>, tonic::Status> {
let job_id = request.into_inner().job_id;
debug!("Received get_job_status request for job {}", job_id);
- let job_meta = self
- .state
- .get_job_metadata(&self.namespace, &job_id)
- .await
- .map_err(|e| {
- let msg = format!("Error reading job metadata: {}", e);
- error!("{}", msg);
- tonic::Status::internal(msg)
- })?;
+ let job_meta = self.state.get_job_metadata(&job_id).await.map_err(|e| {
+ let msg = format!("Error reading job metadata: {}", e);
+ error!("{}", msg);
+ tonic::Status::internal(msg)
+ })?;
Ok(Response::new(GetJobStatusResult {
status: Some(job_meta),
}))
@@ -468,7 +450,7 @@ mod test {
let state = Arc::new(StandaloneClient::try_new_temporary()?);
let namespace = "default";
let scheduler = SchedulerServer::new(state.clone(),
namespace.to_owned());
- let state = SchedulerState::new(state);
+ let state = SchedulerState::new(state, namespace.to_string());
let exec_meta = ExecutorMetadata {
id: "abc".to_owned(),
host: "".to_owned(),
@@ -487,10 +469,7 @@ mod test {
// no response task since we told the scheduler we didn't want to
accept one
assert!(response.task.is_none());
// executor should be registered
- assert_eq!(
- state.get_executors_metadata(namespace).await.unwrap().len(),
- 1
- );
+ assert_eq!(state.get_executors_metadata().await.unwrap().len(), 1);
let request: Request<PollWorkParams> = Request::new(PollWorkParams {
metadata: Some(exec_meta.clone()),
@@ -505,10 +484,7 @@ mod test {
        // still no response task since there are no tasks in the scheduler
assert!(response.task.is_none());
// executor should be registered
- assert_eq!(
- state.get_executors_metadata(namespace).await.unwrap().len(),
- 1
- );
+ assert_eq!(state.get_executors_metadata().await.unwrap().len(), 1);
Ok(())
}
}
diff --git a/rust/ballista/rust/scheduler/src/state/etcd.rs
b/rust/ballista/rust/scheduler/src/state/etcd.rs
index ced2461..807477d 100644
--- a/rust/ballista/rust/scheduler/src/state/etcd.rs
+++ b/rust/ballista/rust/scheduler/src/state/etcd.rs
@@ -17,15 +17,18 @@
//! Etcd config backend.
-use std::time::Duration;
+use std::{task::Poll, time::Duration};
use crate::state::ConfigBackendClient;
use ballista_core::error::{ballista_error, Result};
-use etcd_client::{GetOptions, LockResponse, PutOptions};
+use etcd_client::{
+ GetOptions, LockResponse, PutOptions, WatchOptions, WatchStream, Watcher,
+};
+use futures::{Stream, StreamExt};
use log::warn;
-use super::Lock;
+use super::{Lock, Watch, WatchEvent};
/// A [`ConfigBackendClient`] implementation that uses etcd to save cluster
configuration.
#[derive(Clone)]
@@ -105,6 +108,87 @@ impl ConfigBackendClient for EtcdClient {
})?;
Ok(Box::new(EtcdLockGuard { etcd, lock }))
}
+
+ async fn watch(&self, prefix: String) -> Result<Box<dyn Watch>> {
+ let mut etcd = self.etcd.clone();
+ let options = WatchOptions::new().with_prefix();
+ let (watcher, stream) = etcd.watch(prefix,
Some(options)).await.map_err(|e| {
+ warn!("etcd watch failed: {}", e);
+ ballista_error("etcd watch failed")
+ })?;
+ Ok(Box::new(EtcdWatch {
+ watcher,
+ stream,
+ buffered_events: Vec::new(),
+ }))
+ }
+}
+
+struct EtcdWatch {
+ watcher: Watcher,
+ stream: WatchStream,
+ buffered_events: Vec<WatchEvent>,
+}
+
+#[tonic::async_trait]
+impl Watch for EtcdWatch {
+ async fn cancel(&mut self) -> Result<()> {
+ self.watcher.cancel().await.map_err(|e| {
+ warn!("etcd watch cancel failed: {}", e);
+ ballista_error("etcd watch cancel failed")
+ })
+ }
+}
+
+impl Stream for EtcdWatch {
+ type Item = WatchEvent;
+
+ fn poll_next(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ let self_mut = self.get_mut();
+ if let Some(event) = self_mut.buffered_events.pop() {
+ Poll::Ready(Some(event))
+ } else {
+ loop {
+ match self_mut.stream.poll_next_unpin(cx) {
+ Poll::Ready(Some(Err(e))) => {
+ warn!("Error when watching etcd prefix: {}", e);
+ continue;
+ }
+ Poll::Ready(Some(Ok(v))) => {
+
self_mut.buffered_events.extend(v.events().iter().map(|ev| {
+ match ev.event_type() {
+ etcd_client::EventType::Put => {
+ let kv = ev.kv().unwrap();
+ WatchEvent::Put(
+ kv.key_str().unwrap().to_string(),
+ kv.value().to_owned(),
+ )
+ }
+ etcd_client::EventType::Delete => {
+ let kv = ev.kv().unwrap();
+
WatchEvent::Delete(kv.key_str().unwrap().to_string())
+ }
+ }
+ }));
+ if let Some(event) = self_mut.buffered_events.pop() {
+ return Poll::Ready(Some(event));
+ } else {
+ continue;
+ }
+ }
+ Poll::Ready(None) => return Poll::Ready(None),
+ Poll::Pending => return Poll::Pending,
+ }
+ }
+ }
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ self.stream.size_hint()
+ }
}
struct EtcdLockGuard {
diff --git a/rust/ballista/rust/scheduler/src/state/mod.rs
b/rust/ballista/rust/scheduler/src/state/mod.rs
index 794e58f..a15efd6 100644
--- a/rust/ballista/rust/scheduler/src/state/mod.rs
+++ b/rust/ballista/rust/scheduler/src/state/mod.rs
@@ -20,7 +20,8 @@ use std::{
};
use datafusion::physical_plan::ExecutionPlan;
-use log::{debug, info};
+use futures::{Stream, StreamExt};
+use log::{debug, error, info};
use prost::Message;
use tokio::sync::OwnedMutexGuard;
@@ -69,27 +70,46 @@ pub trait ConfigBackendClient: Send + Sync {
) -> Result<()>;
async fn lock(&self) -> Result<Box<dyn Lock>>;
+
+ /// Watch all events that happen on a specific prefix.
+ async fn watch(&self, prefix: String) -> Result<Box<dyn Watch>>;
+}
+
+/// A Watch is a cancelable stream of put or delete events in the
[ConfigBackendClient]
+#[tonic::async_trait]
+pub trait Watch: Stream<Item = WatchEvent> + Send + Unpin {
+ async fn cancel(&mut self) -> Result<()>;
+}
+
+#[derive(Debug, PartialEq)]
+pub enum WatchEvent {
+ /// Contains the inserted or updated key and the new value
+ Put(String, Vec<u8>),
+
+ /// Contains the deleted key
+ Delete(String),
}
#[derive(Clone)]
pub(super) struct SchedulerState {
config_client: Arc<dyn ConfigBackendClient>,
+ namespace: String,
}
impl SchedulerState {
- pub fn new(config_client: Arc<dyn ConfigBackendClient>) -> Self {
- Self { config_client }
+ pub fn new(config_client: Arc<dyn ConfigBackendClient>, namespace: String)
-> Self {
+ Self {
+ config_client,
+ namespace,
+ }
}
- pub async fn get_executors_metadata(
- &self,
- namespace: &str,
- ) -> Result<Vec<ExecutorMeta>> {
+ pub async fn get_executors_metadata(&self) -> Result<Vec<ExecutorMeta>> {
let mut result = vec![];
let entries = self
.config_client
- .get_from_prefix(&get_executors_prefix(namespace))
+ .get_from_prefix(&get_executors_prefix(&self.namespace))
.await?;
for (_key, entry) in entries {
let meta: ExecutorMetadata = decode_protobuf(&entry)?;
@@ -98,12 +118,8 @@ impl SchedulerState {
Ok(result)
}
- pub async fn save_executor_metadata(
- &self,
- namespace: &str,
- meta: ExecutorMeta,
- ) -> Result<()> {
- let key = get_executor_key(namespace, &meta.id);
+ pub async fn save_executor_metadata(&self, meta: ExecutorMeta) ->
Result<()> {
+ let key = get_executor_key(&self.namespace, &meta.id);
let meta: ExecutorMetadata = meta.into();
let value: Vec<u8> = encode_protobuf(&meta)?;
self.config_client.put(key, value, Some(LEASE_TIME)).await
@@ -111,22 +127,17 @@ impl SchedulerState {
pub async fn save_job_metadata(
&self,
- namespace: &str,
job_id: &str,
status: &JobStatus,
) -> Result<()> {
debug!("Saving job metadata: {:?}", status);
- let key = get_job_key(namespace, job_id);
+ let key = get_job_key(&self.namespace, job_id);
let value = encode_protobuf(status)?;
self.config_client.put(key, value, None).await
}
- pub async fn get_job_metadata(
- &self,
- namespace: &str,
- job_id: &str,
- ) -> Result<JobStatus> {
- let key = get_job_key(namespace, job_id);
+ pub async fn get_job_metadata(&self, job_id: &str) -> Result<JobStatus> {
+ let key = get_job_key(&self.namespace, job_id);
let value = &self.config_client.get(&key).await?;
if value.is_empty() {
return Err(BallistaError::General(format!(
@@ -138,14 +149,10 @@ impl SchedulerState {
Ok(value)
}
- pub async fn save_task_status(
- &self,
- namespace: &str,
- status: &TaskStatus,
- ) -> Result<()> {
+ pub async fn save_task_status(&self, status: &TaskStatus) -> Result<()> {
let partition_id = status.partition_id.as_ref().unwrap();
let key = get_task_status_key(
- namespace,
+ &self.namespace,
&partition_id.job_id,
partition_id.stage_id as usize,
partition_id.partition_id as usize,
@@ -156,12 +163,11 @@ impl SchedulerState {
pub async fn _get_task_status(
&self,
- namespace: &str,
job_id: &str,
stage_id: usize,
partition_id: usize,
) -> Result<TaskStatus> {
- let key = get_task_status_key(namespace, job_id, stage_id,
partition_id);
+ let key = get_task_status_key(&self.namespace, job_id, stage_id,
partition_id);
let value = &self.config_client.clone().get(&key).await?;
if value.is_empty() {
return Err(BallistaError::General(format!(
@@ -176,12 +182,11 @@ impl SchedulerState {
// "Unnecessary" lifetime syntax due to
https://github.com/rust-lang/rust/issues/63033
pub async fn save_stage_plan<'a>(
&'a self,
- namespace: &'a str,
job_id: &'a str,
stage_id: usize,
plan: Arc<dyn ExecutionPlan>,
) -> Result<()> {
- let key = get_stage_plan_key(namespace, job_id, stage_id);
+ let key = get_stage_plan_key(&self.namespace, job_id, stage_id);
let value = {
let proto: PhysicalPlanNode = plan.try_into()?;
encode_protobuf(&proto)?
@@ -191,11 +196,10 @@ impl SchedulerState {
pub async fn get_stage_plan(
&self,
- namespace: &str,
job_id: &str,
stage_id: usize,
) -> Result<Arc<dyn ExecutionPlan>> {
- let key = get_stage_plan_key(namespace, job_id, stage_id);
+ let key = get_stage_plan_key(&self.namespace, job_id, stage_id);
let value = &self.config_client.get(&key).await?;
if value.is_empty() {
return Err(BallistaError::General(format!(
@@ -209,26 +213,21 @@ impl SchedulerState {
pub async fn assign_next_schedulable_task(
&self,
- namespace: &str,
executor_id: &str,
) -> Result<Option<(TaskStatus, Arc<dyn ExecutionPlan>)>> {
let kvs: HashMap<String, Vec<u8>> = self
.config_client
- .get_from_prefix(&get_task_prefix(namespace))
+ .get_from_prefix(&get_task_prefix(&self.namespace))
.await?
.into_iter()
.collect();
- let executors = self.get_executors_metadata(namespace).await?;
+ let executors = self.get_executors_metadata().await?;
'tasks: for (_key, value) in kvs.iter() {
let mut status: TaskStatus = decode_protobuf(&value)?;
if status.status.is_none() {
let partition = status.partition_id.as_ref().unwrap();
let plan = self
- .get_stage_plan(
- namespace,
- &partition.job_id,
- partition.stage_id as usize,
- )
+ .get_stage_plan(&partition.job_id, partition.stage_id as
usize)
.await?;
// Let's try to resolve any unresolved shuffles we find
@@ -242,7 +241,7 @@ impl SchedulerState {
for partition_id in
0..unresolved_shuffle.partition_count {
let referenced_task = kvs
.get(&get_task_status_key(
- namespace,
+ &self.namespace,
&partition.job_id,
stage_id,
partition_id,
@@ -286,7 +285,7 @@ impl SchedulerState {
status.status = Some(task_status::Status::Running(RunningTask {
executor_id: executor_id.to_owned(),
}));
- self.save_task_status(namespace, &status).await?;
+ self.save_task_status(&status).await?;
return Ok(Some((status, plan)));
}
}
@@ -298,34 +297,58 @@ impl SchedulerState {
self.config_client.lock().await
}
- pub async fn synchronize_job_status(&self, namespace: &str) -> Result<()> {
- let kvs = self
+ /// This function starts a watch over the task keys. Whenever a task
changes, it re-evaluates
+ /// the status for the parent job and updates it accordingly.
+ ///
+    /// The future returned by this function never returns (unless an error
happens), so it is wise
+    /// to wrap calls to this method in [tokio::spawn].
+ pub async fn synchronize_job_status_loop(&self) -> Result<()> {
+ let watch = self
+ .config_client
+ .watch(get_task_prefix(&self.namespace))
+ .await?;
+ watch.for_each(|event: WatchEvent| async {
+ let key = match event {
+ WatchEvent::Put(key, _value) => key,
+ WatchEvent::Delete(key) => key
+ };
+ let job_id = extract_job_id_from_task_key(&key).unwrap();
+ match self.lock().await {
+ Ok(mut lock) => {
+ if let Err(e) = self.synchronize_job_status(job_id).await {
+ error!("Could not update job status for {}. This job
might be stuck forever. Error: {}", job_id, e);
+ }
+ lock.unlock().await;
+ },
+ Err(e) => error!("Could not lock config backend. Job {} will
have an unsynchronized status and might be stuck forever. Error: {}", job_id, e)
+ }
+ }).await;
+
+ Ok(())
+ }
+
+ async fn synchronize_job_status(&self, job_id: &str) -> Result<()> {
+ let value = self
.config_client
- .get_from_prefix(&get_job_prefix(namespace))
+ .get(&get_job_key(&self.namespace, job_id))
.await?;
let executors: HashMap<String, ExecutorMeta> = self
- .get_executors_metadata(namespace)
+ .get_executors_metadata()
.await?
.into_iter()
.map(|meta| (meta.id.to_string(), meta))
.collect();
- for (key, value) in kvs {
- let job_id = extract_job_id_from_key(&key)?;
- let status: JobStatus = decode_protobuf(&value)?;
- let new_status = self
- .get_job_status_from_tasks(namespace, job_id, &executors)
- .await?;
- if let Some(new_status) = new_status {
- if status != new_status {
- info!(
- "Changing status for job {} to {:?}",
- job_id, new_status.status
- );
- debug!("Old status: {:?}", status);
- debug!("New status: {:?}", new_status);
- self.save_job_metadata(namespace, job_id, &new_status)
- .await?;
- }
+ let status: JobStatus = decode_protobuf(&value)?;
+ let new_status = self.get_job_status_from_tasks(job_id,
&executors).await?;
+ if let Some(new_status) = new_status {
+ if status != new_status {
+ info!(
+ "Changing status for job {} to {:?}",
+ job_id, new_status.status
+ );
+ debug!("Old status: {:?}", status);
+ debug!("New status: {:?}", new_status);
+ self.save_job_metadata(job_id, &new_status).await?;
}
}
Ok(())
@@ -333,13 +356,12 @@ impl SchedulerState {
async fn get_job_status_from_tasks(
&self,
- namespace: &str,
job_id: &str,
executors: &HashMap<String, ExecutorMeta>,
) -> Result<Option<JobStatus>> {
let statuses = self
.config_client
- .get_from_prefix(&get_task_prefix_for_job(namespace, job_id))
+ .get_from_prefix(&get_task_prefix_for_job(&self.namespace, job_id))
.await?
.into_iter()
.map(|(_k, v)| decode_protobuf::<TaskStatus>(&v))
@@ -446,12 +468,6 @@ fn get_job_prefix(namespace: &str) -> String {
format!("/ballista/{}/jobs", namespace)
}
-fn extract_job_id_from_key(job_key: &str) -> Result<&str> {
- job_key.split('/').nth(4).ok_or_else(|| {
- BallistaError::Internal(format!("Unexpected job key: {}", job_key))
- })
-}
-
fn get_job_key(namespace: &str, id: &str) -> String {
format!("{}/{}", get_job_prefix(namespace), id)
}
@@ -478,6 +494,12 @@ fn get_task_status_key(
)
}
+fn extract_job_id_from_task_key(job_key: &str) -> Result<&str> {
+ job_key.split('/').nth(4).ok_or_else(|| {
+ BallistaError::Internal(format!("Unexpected task key: {}", job_key))
+ })
+}
+
fn get_stage_plan_key(namespace: &str, job_id: &str, stage_id: usize) ->
String {
format!("/ballista/{}/stages/{}/{}", namespace, job_id, stage_id,)
}
@@ -514,44 +536,39 @@ mod test {
};
use ballista_core::{error::BallistaError, serde::scheduler::ExecutorMeta};
- use super::{SchedulerState, StandaloneClient};
+ use super::{
+ extract_job_id_from_task_key, get_task_status_key, SchedulerState,
+ StandaloneClient,
+ };
#[tokio::test]
async fn executor_metadata() -> Result<(), BallistaError> {
- let state =
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
+ let state = SchedulerState::new(
+ Arc::new(StandaloneClient::try_new_temporary()?),
+ "test".to_string(),
+ );
let meta = ExecutorMeta {
id: "123".to_owned(),
host: "localhost".to_owned(),
port: 123,
};
- state.save_executor_metadata("test", meta.clone()).await?;
- let result = state.get_executors_metadata("test").await?;
+ state.save_executor_metadata(meta.clone()).await?;
+ let result = state.get_executors_metadata().await?;
assert_eq!(vec![meta], result);
Ok(())
}
#[tokio::test]
- async fn executor_metadata_empty() -> Result<(), BallistaError> {
- let state =
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
- let meta = ExecutorMeta {
- id: "123".to_owned(),
- host: "localhost".to_owned(),
- port: 123,
- };
- state.save_executor_metadata("test", meta.clone()).await?;
- let result = state.get_executors_metadata("test2").await?;
- assert!(result.is_empty());
- Ok(())
- }
-
- #[tokio::test]
async fn job_metadata() -> Result<(), BallistaError> {
- let state =
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
+ let state = SchedulerState::new(
+ Arc::new(StandaloneClient::try_new_temporary()?),
+ "test".to_string(),
+ );
let meta = JobStatus {
status: Some(job_status::Status::Queued(QueuedJob {})),
};
- state.save_job_metadata("test", "job", &meta).await?;
- let result = state.get_job_metadata("test", "job").await?;
+ state.save_job_metadata("job", &meta).await?;
+ let result = state.get_job_metadata("job").await?;
assert!(result.status.is_some());
match result.status.unwrap() {
job_status::Status::Queued(_) => (),
@@ -562,19 +579,25 @@ mod test {
#[tokio::test]
async fn job_metadata_non_existant() -> Result<(), BallistaError> {
- let state =
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
+ let state = SchedulerState::new(
+ Arc::new(StandaloneClient::try_new_temporary()?),
+ "test".to_string(),
+ );
let meta = JobStatus {
status: Some(job_status::Status::Queued(QueuedJob {})),
};
- state.save_job_metadata("test", "job", &meta).await?;
- let result = state.get_job_metadata("test2", "job2").await;
+ state.save_job_metadata("job", &meta).await?;
+ let result = state.get_job_metadata("job2").await;
assert!(result.is_err());
Ok(())
}
#[tokio::test]
async fn task_status() -> Result<(), BallistaError> {
- let state =
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
+ let state = SchedulerState::new(
+ Arc::new(StandaloneClient::try_new_temporary()?),
+ "test".to_string(),
+ );
let meta = TaskStatus {
status: Some(task_status::Status::Failed(FailedTask {
error: "error".to_owned(),
@@ -585,8 +608,8 @@ mod test {
partition_id: 2,
}),
};
- state.save_task_status("test", &meta).await?;
- let result = state._get_task_status("test", "job", 1, 2).await?;
+ state.save_task_status(&meta).await?;
+ let result = state._get_task_status("job", 1, 2).await?;
assert!(result.status.is_some());
match result.status.unwrap() {
task_status::Status::Failed(_) => (),
@@ -597,7 +620,10 @@ mod test {
#[tokio::test]
async fn task_status_non_existant() -> Result<(), BallistaError> {
- let state =
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
+ let state = SchedulerState::new(
+ Arc::new(StandaloneClient::try_new_temporary()?),
+ "test".to_string(),
+ );
let meta = TaskStatus {
status: Some(task_status::Status::Failed(FailedTask {
error: "error".to_owned(),
@@ -608,40 +634,40 @@ mod test {
partition_id: 2,
}),
};
- state.save_task_status("test", &meta).await?;
- let result = state._get_task_status("test", "job", 25, 2).await;
+ state.save_task_status(&meta).await?;
+ let result = state._get_task_status("job", 25, 2).await;
assert!(result.is_err());
Ok(())
}
#[tokio::test]
async fn task_synchronize_job_status_queued() -> Result<(), BallistaError>
{
- let state =
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
- let namespace = "default";
+ let state = SchedulerState::new(
+ Arc::new(StandaloneClient::try_new_temporary()?),
+ "test".to_string(),
+ );
let job_id = "job";
let job_status = JobStatus {
status: Some(job_status::Status::Queued(QueuedJob {})),
};
- state
- .save_job_metadata(namespace, job_id, &job_status)
- .await?;
- state.synchronize_job_status(namespace).await?;
- let result = state.get_job_metadata(namespace, job_id).await?;
+ state.save_job_metadata(job_id, &job_status).await?;
+ state.synchronize_job_status(job_id).await?;
+ let result = state.get_job_metadata(job_id).await?;
assert_eq!(result, job_status);
Ok(())
}
#[tokio::test]
async fn task_synchronize_job_status_running() -> Result<(),
BallistaError> {
- let state =
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
- let namespace = "default";
+ let state = SchedulerState::new(
+ Arc::new(StandaloneClient::try_new_temporary()?),
+ "test".to_string(),
+ );
let job_id = "job";
let job_status = JobStatus {
status: Some(job_status::Status::Running(RunningJob {})),
};
- state
- .save_job_metadata(namespace, job_id, &job_status)
- .await?;
+ state.save_job_metadata(job_id, &job_status).await?;
let meta = TaskStatus {
status: Some(task_status::Status::Completed(CompletedTask {
executor_id: "".to_owned(),
@@ -652,7 +678,7 @@ mod test {
partition_id: 0,
}),
};
- state.save_task_status(namespace, &meta).await?;
+ state.save_task_status(&meta).await?;
let meta = TaskStatus {
status: Some(task_status::Status::Running(RunningTask {
executor_id: "".to_owned(),
@@ -663,24 +689,24 @@ mod test {
partition_id: 1,
}),
};
- state.save_task_status(namespace, &meta).await?;
- state.synchronize_job_status(namespace).await?;
- let result = state.get_job_metadata(namespace, job_id).await?;
+ state.save_task_status(&meta).await?;
+ state.synchronize_job_status(job_id).await?;
+ let result = state.get_job_metadata(job_id).await?;
assert_eq!(result, job_status);
Ok(())
}
#[tokio::test]
async fn task_synchronize_job_status_running2() -> Result<(),
BallistaError> {
- let state =
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
- let namespace = "default";
+ let state = SchedulerState::new(
+ Arc::new(StandaloneClient::try_new_temporary()?),
+ "test".to_string(),
+ );
let job_id = "job";
let job_status = JobStatus {
status: Some(job_status::Status::Running(RunningJob {})),
};
- state
- .save_job_metadata(namespace, job_id, &job_status)
- .await?;
+ state.save_job_metadata(job_id, &job_status).await?;
let meta = TaskStatus {
status: Some(task_status::Status::Completed(CompletedTask {
executor_id: "".to_owned(),
@@ -691,7 +717,7 @@ mod test {
partition_id: 0,
}),
};
- state.save_task_status(namespace, &meta).await?;
+ state.save_task_status(&meta).await?;
let meta = TaskStatus {
status: None,
partition_id: Some(PartitionId {
@@ -700,24 +726,24 @@ mod test {
partition_id: 1,
}),
};
- state.save_task_status(namespace, &meta).await?;
- state.synchronize_job_status(namespace).await?;
- let result = state.get_job_metadata(namespace, job_id).await?;
+ state.save_task_status(&meta).await?;
+ state.synchronize_job_status(job_id).await?;
+ let result = state.get_job_metadata(job_id).await?;
assert_eq!(result, job_status);
Ok(())
}
#[tokio::test]
async fn task_synchronize_job_status_completed() -> Result<(),
BallistaError> {
- let state =
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
- let namespace = "default";
+ let state = SchedulerState::new(
+ Arc::new(StandaloneClient::try_new_temporary()?),
+ "test".to_string(),
+ );
let job_id = "job";
let job_status = JobStatus {
status: Some(job_status::Status::Running(RunningJob {})),
};
- state
- .save_job_metadata(namespace, job_id, &job_status)
- .await?;
+ state.save_job_metadata(job_id, &job_status).await?;
let meta = TaskStatus {
status: Some(task_status::Status::Completed(CompletedTask {
executor_id: "".to_owned(),
@@ -728,7 +754,7 @@ mod test {
partition_id: 0,
}),
};
- state.save_task_status(namespace, &meta).await?;
+ state.save_task_status(&meta).await?;
let meta = TaskStatus {
status: Some(task_status::Status::Completed(CompletedTask {
executor_id: "".to_owned(),
@@ -739,9 +765,9 @@ mod test {
partition_id: 1,
}),
};
- state.save_task_status(namespace, &meta).await?;
- state.synchronize_job_status(namespace).await?;
- let result = state.get_job_metadata(namespace, job_id).await?;
+ state.save_task_status(&meta).await?;
+ state.synchronize_job_status(job_id).await?;
+ let result = state.get_job_metadata(job_id).await?;
match result.status.unwrap() {
job_status::Status::Completed(_) => (),
status => panic!("Received status: {:?}", status),
@@ -751,15 +777,15 @@ mod test {
#[tokio::test]
async fn task_synchronize_job_status_completed2() -> Result<(),
BallistaError> {
- let state =
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
- let namespace = "default";
+ let state = SchedulerState::new(
+ Arc::new(StandaloneClient::try_new_temporary()?),
+ "test".to_string(),
+ );
let job_id = "job";
let job_status = JobStatus {
status: Some(job_status::Status::Queued(QueuedJob {})),
};
- state
- .save_job_metadata(namespace, job_id, &job_status)
- .await?;
+ state.save_job_metadata(job_id, &job_status).await?;
let meta = TaskStatus {
status: Some(task_status::Status::Completed(CompletedTask {
executor_id: "".to_owned(),
@@ -770,7 +796,7 @@ mod test {
partition_id: 0,
}),
};
- state.save_task_status(namespace, &meta).await?;
+ state.save_task_status(&meta).await?;
let meta = TaskStatus {
status: Some(task_status::Status::Completed(CompletedTask {
executor_id: "".to_owned(),
@@ -781,9 +807,9 @@ mod test {
partition_id: 1,
}),
};
- state.save_task_status(namespace, &meta).await?;
- state.synchronize_job_status(namespace).await?;
- let result = state.get_job_metadata(namespace, job_id).await?;
+ state.save_task_status(&meta).await?;
+ state.synchronize_job_status(job_id).await?;
+ let result = state.get_job_metadata(job_id).await?;
match result.status.unwrap() {
job_status::Status::Completed(_) => (),
status => panic!("Received status: {:?}", status),
@@ -793,15 +819,15 @@ mod test {
#[tokio::test]
async fn task_synchronize_job_status_failed() -> Result<(), BallistaError>
{
- let state =
SchedulerState::new(Arc::new(StandaloneClient::try_new_temporary()?));
- let namespace = "default";
+ let state = SchedulerState::new(
+ Arc::new(StandaloneClient::try_new_temporary()?),
+ "test".to_string(),
+ );
let job_id = "job";
let job_status = JobStatus {
status: Some(job_status::Status::Running(RunningJob {})),
};
- state
- .save_job_metadata(namespace, job_id, &job_status)
- .await?;
+ state.save_job_metadata(job_id, &job_status).await?;
let meta = TaskStatus {
status: Some(task_status::Status::Completed(CompletedTask {
executor_id: "".to_owned(),
@@ -812,7 +838,7 @@ mod test {
partition_id: 0,
}),
};
- state.save_task_status(namespace, &meta).await?;
+ state.save_task_status(&meta).await?;
let meta = TaskStatus {
status: Some(task_status::Status::Failed(FailedTask {
error: "".to_owned(),
@@ -823,7 +849,7 @@ mod test {
partition_id: 1,
}),
};
- state.save_task_status(namespace, &meta).await?;
+ state.save_task_status(&meta).await?;
let meta = TaskStatus {
status: None,
partition_id: Some(PartitionId {
@@ -832,13 +858,23 @@ mod test {
partition_id: 2,
}),
};
- state.save_task_status(namespace, &meta).await?;
- state.synchronize_job_status(namespace).await?;
- let result = state.get_job_metadata(namespace, job_id).await?;
+ state.save_task_status(&meta).await?;
+ state.synchronize_job_status(job_id).await?;
+ let result = state.get_job_metadata(job_id).await?;
match result.status.unwrap() {
job_status::Status::Failed(_) => (),
status => panic!("Received status: {:?}", status),
}
Ok(())
}
+
+ #[test]
+ fn task_extract_job_id_from_task_key() {
+ let job_id = "foo";
+ assert_eq!(
+ extract_job_id_from_task_key(&get_task_status_key("namespace",
job_id, 0, 1))
+ .unwrap(),
+ job_id
+ );
+ }
}
diff --git a/rust/ballista/rust/scheduler/src/state/standalone.rs
b/rust/ballista/rust/scheduler/src/state/standalone.rs
index e07d45e..69805c0 100644
--- a/rust/ballista/rust/scheduler/src/state/standalone.rs
+++ b/rust/ballista/rust/scheduler/src/state/standalone.rs
@@ -15,15 +15,17 @@
// specific language governing permissions and limitations
// under the License.
-use std::{sync::Arc, time::Duration};
+use std::{sync::Arc, task::Poll, time::Duration};
use crate::state::ConfigBackendClient;
use ballista_core::error::{ballista_error, BallistaError, Result};
+use futures::{FutureExt, Stream};
use log::warn;
+use sled::{Event, Subscriber};
use tokio::sync::Mutex;
-use super::Lock;
+use super::{Lock, Watch, WatchEvent};
/// A [`ConfigBackendClient`] implementation that uses file-based storage to
save cluster configuration.
#[derive(Clone)]
@@ -106,13 +108,57 @@ impl ConfigBackendClient for StandaloneClient {
async fn lock(&self) -> Result<Box<dyn Lock>> {
Ok(Box::new(self.lock.clone().lock_owned().await))
}
+
+ async fn watch(&self, prefix: String) -> Result<Box<dyn Watch>> {
+ Ok(Box::new(SledWatch {
+ subscriber: self.db.watch_prefix(prefix),
+ }))
+ }
+}
+
+struct SledWatch {
+ subscriber: Subscriber,
+}
+
+#[tonic::async_trait]
+impl Watch for SledWatch {
+ async fn cancel(&mut self) -> Result<()> {
+ Ok(())
+ }
+}
+
+impl Stream for SledWatch {
+ type Item = WatchEvent;
+
+ fn poll_next(
+ self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> std::task::Poll<Option<Self::Item>> {
+ match self.get_mut().subscriber.poll_unpin(cx) {
+ Poll::Pending => Poll::Pending,
+ Poll::Ready(None) => Poll::Ready(None),
+ Poll::Ready(Some(Event::Insert { key, value })) => {
+ let key = std::str::from_utf8(&key).unwrap().to_owned();
+ Poll::Ready(Some(WatchEvent::Put(key, value.to_vec())))
+ }
+ Poll::Ready(Some(Event::Remove { key })) => {
+ let key = std::str::from_utf8(&key).unwrap().to_owned();
+ Poll::Ready(Some(WatchEvent::Delete(key)))
+ }
+ }
+ }
+
+ fn size_hint(&self) -> (usize, Option<usize>) {
+ self.subscriber.size_hint()
+ }
}
#[cfg(test)]
mod tests {
- use crate::state::ConfigBackendClient;
+ use crate::state::{ConfigBackendClient, Watch, WatchEvent};
use super::StandaloneClient;
+ use futures::StreamExt;
use std::result::Result;
fn create_instance() -> Result<StandaloneClient, Box<dyn
std::error::Error>> {
@@ -158,4 +204,25 @@ mod tests {
);
Ok(())
}
+
+ #[tokio::test]
+ async fn read_watch() -> Result<(), Box<dyn std::error::Error>> {
+ let client = create_instance()?;
+ let key = "key";
+ let value = "value".as_bytes();
+ let mut watch: Box<dyn Watch> = client.watch(key.to_owned()).await?;
+ client.put(key.to_owned(), value.to_vec(), None).await?;
+ assert_eq!(
+ watch.next().await,
+ Some(WatchEvent::Put(key.to_owned(), value.to_owned()))
+ );
+ let value2 = "value2".as_bytes();
+ client.put(key.to_owned(), value2.to_vec(), None).await?;
+ assert_eq!(
+ watch.next().await,
+ Some(WatchEvent::Put(key.to_owned(), value2.to_owned()))
+ );
+ watch.cancel().await?;
+ Ok(())
+ }
}