This is an automated email from the ASF dual-hosted git repository. thinkharderdev pushed a commit to branch cluster-state-refactor-2 in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git
commit 950882bd0e18f2c626dc0939d7fd28972f601b8f Author: Dan Harris <[email protected]> AuthorDate: Fri Feb 3 09:49:08 2023 +0200 WIP --- ballista/core/proto/ballista.proto | 29 +- ballista/core/src/serde/generated/ballista.rs | 46 +- ballista/scheduler/src/api/handlers.rs | 5 +- ballista/scheduler/src/bin/main.rs | 14 +- ballista/scheduler/src/cluster/event/mod.rs | 318 ++++++++++++++ ballista/scheduler/src/cluster/kv.rs | 480 +++++++++++++++++++++ ballista/scheduler/src/cluster/memory.rs | 353 +++++++++++++++ ballista/scheduler/src/cluster/mod.rs | 287 ++++++++++++ ballista/scheduler/src/cluster/storage/etcd.rs | 346 +++++++++++++++ ballista/scheduler/src/cluster/storage/mod.rs | 137 ++++++ ballista/scheduler/src/cluster/storage/sled.rs | 401 +++++++++++++++++ ballista/scheduler/src/lib.rs | 1 + ballista/scheduler/src/scheduler_process.rs | 5 +- ballista/scheduler/src/scheduler_server/mod.rs | 14 +- ballista/scheduler/src/standalone.rs | 15 +- ballista/scheduler/src/state/execution_graph.rs | 51 ++- .../scheduler/src/state/execution_graph_dot.rs | 14 +- ballista/scheduler/src/state/executor_manager.rs | 4 +- ballista/scheduler/src/state/mod.rs | 17 +- ballista/scheduler/src/state/task_manager.rs | 258 ++++------- 20 files changed, 2586 insertions(+), 209 deletions(-) diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto index 76d3521f..6c9a021a 100644 --- a/ballista/core/proto/ballista.proto +++ b/ballista/core/proto/ballista.proto @@ -651,6 +651,7 @@ message ExecutorMetadata { ExecutorSpecification specification = 5; } + // Used by grpc message ExecutorRegistration { string id = 1; @@ -698,6 +699,15 @@ message ExecutorResource { } } +message AvailableTaskSlots { + string executor_id = 1; + uint32 slots = 2; + } + +message ExecutorTaskSlots { + repeated AvailableTaskSlots task_slots = 1; +} + message ExecutorData { string executor_id = 1; repeated ExecutorResourcePair resources = 2; @@ -905,18 +915,33 @@ message GetJobStatusParams { 
message SuccessfulJob { repeated PartitionLocation partition_location = 1; + uint64 queued_at = 2; + uint64 started_at = 3; + uint64 ended_at = 4; } -message QueuedJob {} +message QueuedJob { + uint64 queued_at = 1; +} // TODO: add progress report -message RunningJob {} +message RunningJob { + uint64 queued_at = 1; + uint64 started_at = 2; + string scheduler = 3; +} message FailedJob { string error = 1; + uint64 queued_at = 2; + uint64 started_at = 3; + uint64 ended_at = 4; } message JobStatus { + string job_id = 5; + string job_name = 6; + oneof status { QueuedJob queued = 1; RunningJob running = 2; diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs index 8f0bc9bc..37b937ba 100644 --- a/ballista/core/src/serde/generated/ballista.rs +++ b/ballista/core/src/serde/generated/ballista.rs @@ -1058,6 +1058,18 @@ pub mod executor_resource { } } #[derive(Clone, PartialEq, ::prost::Message)] +pub struct AvailableTaskSlots { + #[prost(string, tag = "1")] + pub executor_id: ::prost::alloc::string::String, + #[prost(uint32, tag = "2")] + pub slots: u32, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ExecutorTaskSlots { + #[prost(message, repeated, tag = "1")] + pub task_slots: ::prost::alloc::vec::Vec<AvailableTaskSlots>, +} +#[derive(Clone, PartialEq, ::prost::Message)] pub struct ExecutorData { #[prost(string, tag = "1")] pub executor_id: ::prost::alloc::string::String, @@ -1367,19 +1379,45 @@ pub struct GetJobStatusParams { pub struct SuccessfulJob { #[prost(message, repeated, tag = "1")] pub partition_location: ::prost::alloc::vec::Vec<PartitionLocation>, + #[prost(uint64, tag = "2")] + pub queued_at: u64, + #[prost(uint64, tag = "3")] + pub started_at: u64, + #[prost(uint64, tag = "4")] + pub ended_at: u64, } #[derive(Clone, PartialEq, ::prost::Message)] -pub struct QueuedJob {} +pub struct QueuedJob { + #[prost(uint64, tag = "1")] + pub queued_at: u64, +} /// TODO: add progress report #[derive(Clone, 
PartialEq, ::prost::Message)] -pub struct RunningJob {} +pub struct RunningJob { + #[prost(uint64, tag = "1")] + pub queued_at: u64, + #[prost(uint64, tag = "2")] + pub started_at: u64, + #[prost(string, tag = "3")] + pub scheduler: ::prost::alloc::string::String, +} #[derive(Clone, PartialEq, ::prost::Message)] pub struct FailedJob { #[prost(string, tag = "1")] pub error: ::prost::alloc::string::String, + #[prost(uint64, tag = "2")] + pub queued_at: u64, + #[prost(uint64, tag = "3")] + pub started_at: u64, + #[prost(uint64, tag = "4")] + pub ended_at: u64, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct JobStatus { + #[prost(string, tag = "5")] + pub job_id: ::prost::alloc::string::String, + #[prost(string, tag = "6")] + pub job_name: ::prost::alloc::string::String, #[prost(oneof = "job_status::Status", tags = "1, 2, 3, 4")] pub status: ::core::option::Option<job_status::Status>, } @@ -1985,7 +2023,7 @@ pub mod executor_grpc_client { pub mod scheduler_grpc_server { #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] use tonic::codegen::*; - ///Generated trait containing gRPC methods that should be implemented for use with SchedulerGrpcServer. + /// Generated trait containing gRPC methods that should be implemented for use with SchedulerGrpcServer. #[async_trait] pub trait SchedulerGrpc: Send + Sync + 'static { /// Executors must poll the scheduler for heartbeat and to receive tasks @@ -2531,7 +2569,7 @@ pub mod scheduler_grpc_server { pub mod executor_grpc_server { #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] use tonic::codegen::*; - ///Generated trait containing gRPC methods that should be implemented for use with ExecutorGrpcServer. + /// Generated trait containing gRPC methods that should be implemented for use with ExecutorGrpcServer. 
#[async_trait] pub trait ExecutorGrpc: Send + Sync + 'static { async fn launch_task( diff --git a/ballista/scheduler/src/api/handlers.rs b/ballista/scheduler/src/api/handlers.rs index 7e77aa34..c8ac6970 100644 --- a/ballista/scheduler/src/api/handlers.rs +++ b/ballista/scheduler/src/api/handlers.rs @@ -209,6 +209,7 @@ pub(crate) async fn get_query_stages<T: AsLogicalPlan, U: AsExecutionPlan>( { Ok(warp::reply::json(&QueryStagesResponse { stages: graph + .as_ref() .stages() .iter() .map(|(id, stage)| { @@ -303,7 +304,7 @@ pub(crate) async fn get_job_dot_graph<T: AsLogicalPlan, U: AsExecutionPlan>( .await .map_err(|_| warp::reject())? { - ExecutionGraphDot::generate(graph).map_err(|_| warp::reject()) + ExecutionGraphDot::generate(graph.as_ref()).map_err(|_| warp::reject()) } else { Ok("Not Found".to_string()) } @@ -322,7 +323,7 @@ pub(crate) async fn get_query_stage_dot_graph<T: AsLogicalPlan, U: AsExecutionPl .await .map_err(|_| warp::reject())? { - ExecutionGraphDot::generate_for_query_stage(graph, stage_id) + ExecutionGraphDot::generate_for_query_stage(graph.as_ref(), stage_id) .map_err(|_| warp::reject()) } else { Ok("Not Found".to_string()) diff --git a/ballista/scheduler/src/bin/main.rs b/ballista/scheduler/src/bin/main.rs index 2ad29bd2..ca248825 100644 --- a/ballista/scheduler/src/bin/main.rs +++ b/ballista/scheduler/src/bin/main.rs @@ -44,6 +44,9 @@ mod config { } use ballista_core::config::LogRotationPolicy; +use ballista_core::utils::default_session_builder; +use ballista_scheduler::cluster::memory::{InMemoryClusterState, InMemoryJobState}; +use ballista_scheduler::cluster::BallistaCluster; use ballista_scheduler::config::SchedulerConfig; use ballista_scheduler::state::backend::cluster::DefaultClusterState; use config::prelude::*; @@ -131,7 +134,16 @@ async fn main() -> Result<()> { .finished_job_state_clean_up_interval_seconds, advertise_flight_sql_endpoint: opt.advertise_flight_sql_endpoint, }; - start_server(scheduler_name, config_backend, 
cluster_state, addr, config).await?; + + let cluster = BallistaCluster::new( + Arc::new(InMemoryClusterState::default()), + Arc::new(InMemoryJobState::new( + &scheduler_name, + default_session_builder, + )), + ); + + start_server(scheduler_name, cluster, config_backend, addr, config).await?; Ok(()) } diff --git a/ballista/scheduler/src/cluster/event/mod.rs b/ballista/scheduler/src/cluster/event/mod.rs new file mode 100644 index 00000000..88b3c297 --- /dev/null +++ b/ballista/scheduler/src/cluster/event/mod.rs @@ -0,0 +1,318 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use futures::Stream; +use log::debug; +use parking_lot::RwLock; +use std::collections::BTreeMap; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; +use tokio::sync::broadcast; +use tokio::sync::broadcast::error::TryRecvError; + +// TODO make configurable +const EVENT_BUFFER_SIZE: usize = 256; + +static ID_GEN: AtomicUsize = AtomicUsize::new(0); + +#[derive(Default)] +struct Shared { + subscriptions: AtomicUsize, + wakers: RwLock<BTreeMap<usize, Waker>>, +} + +impl Shared { + pub fn register(&self, subscriber_id: usize, waker: Waker) { + self.wakers.write().insert(subscriber_id, waker); + } + + pub fn deregister(&self, subscriber_id: usize) { + self.wakers.write().remove(&subscriber_id); + } + + pub fn notify(&self) { + let guard = self.wakers.read(); + for waker in guard.values() { + waker.wake_by_ref(); + } + } +} + +pub(crate) struct ClusterEventSender<T: Clone> { + sender: broadcast::Sender<T>, + shared: Arc<Shared>, +} + +impl<T: Clone> ClusterEventSender<T> { + pub fn new(capacity: usize) -> Self { + let (sender, _) = broadcast::channel(capacity); + + Self { + sender, + shared: Arc::new(Shared::default()), + } + } + + pub fn send(&self, event: &T) { + if self.shared.subscriptions.load(Ordering::Acquire) > 0 { + if let Err(e) = self.sender.send(event.clone()) { + debug!("Failed to send event to channel: {}", e); + return; + } + + self.shared.notify(); + } + } + + pub fn subscribe(&self) -> EventSubscriber<T> { + self.shared.subscriptions.fetch_add(1, Ordering::AcqRel); + let id = ID_GEN.fetch_add(1, Ordering::AcqRel); + + EventSubscriber { + id, + receiver: self.sender.subscribe(), + shared: self.shared.clone(), + registered: false, + } + } + + #[cfg(test)] + pub fn registered_wakers(&self) -> usize { + self.shared.wakers.read().len() + } +} + +impl<T: Clone> Default for ClusterEventSender<T> { + fn default() -> Self { + Self::new(EVENT_BUFFER_SIZE) + } +} + +pub struct 
EventSubscriber<T: Clone> { + id: usize, + receiver: broadcast::Receiver<T>, + shared: Arc<Shared>, + registered: bool, +} + +impl<T: Clone> EventSubscriber<T> { + pub fn register(&mut self, waker: Waker) { + if !self.registered { + self.shared.register(self.id, waker); + self.registered = true; + } + } +} + +impl<T: Clone> Drop for EventSubscriber<T> { + fn drop(&mut self) { + self.shared.deregister(self.id); + } +} + +impl<T: Clone> Stream for EventSubscriber<T> { + type Item = T; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Self::Item>> { + loop { + match self.receiver.try_recv() { + Ok(event) => { + self.register(cx.waker().clone()); + return Poll::Ready(Some(event)); + } + Err(TryRecvError::Closed) => return Poll::Ready(None), + Err(TryRecvError::Lagged(n)) => { + debug!("Subscriber lagged by {} message", n); + self.register(cx.waker().clone()); + continue; + } + Err(TryRecvError::Empty) => { + self.register(cx.waker().clone()); + return Poll::Pending; + } + } + } + } +} + +#[cfg(test)] +mod test { + use crate::cluster::event::{ClusterEventSender, EventSubscriber}; + use futures::stream::FuturesUnordered; + use futures::StreamExt; + + async fn collect_events<T: Clone>(mut rx: EventSubscriber<T>) -> Vec<T> { + let mut events = vec![]; + while let Some(event) = rx.next().await { + events.push(event); + } + + events + } + + #[tokio::test] + async fn test_event_subscription() { + let sender = ClusterEventSender::new(100); + + let rx = vec![sender.subscribe(), sender.subscribe(), sender.subscribe()]; + + let mut tasks: FuturesUnordered<_> = rx + .into_iter() + .map(|rx| async move { collect_events(rx).await }) + .collect(); + + let handle = tokio::spawn(async move { + let mut results = vec![]; + while let Some(result) = tasks.next().await { + results.push(result) + } + results + }); + + tokio::spawn(async move { + for i in 0..100 { + sender.send(&i); + } + }); + + let expected: Vec<i32> = (0..100).into_iter().collect(); + 
+ let results = handle.await.unwrap(); + assert_eq!(results.len(), 3); + + for res in results { + assert_eq!(res, expected); + } + } + + #[tokio::test] + async fn test_event_lagged() { + // Created sender with a buffer for only 8 events + let sender = ClusterEventSender::new(8); + + let rx = vec![sender.subscribe(), sender.subscribe(), sender.subscribe()]; + + let mut tasks: FuturesUnordered<_> = rx + .into_iter() + .map(|rx| async move { collect_events(rx).await }) + .collect(); + + let handle = tokio::spawn(async move { + let mut results = vec![]; + while let Some(result) = tasks.next().await { + results.push(result) + } + results + }); + + // Send events faster than they can be consumed by subscribers + tokio::spawn(async move { + for i in 0..100 { + sender.send(&i); + } + }); + + // When we reach capacity older events should be dropped so we only see + // the last 8 events in our subscribers + let expected: Vec<i32> = (92..100).into_iter().collect(); + + let results = handle.await.unwrap(); + assert_eq!(results.len(), 3); + + for res in results { + assert_eq!(res, expected); + } + } + + #[tokio::test] + async fn test_event_skip_unsubscribed() { + let sender = ClusterEventSender::new(100); + + // There are no subscribers yet so this event should be ignored + sender.send(&0); + + let rx = vec![sender.subscribe(), sender.subscribe(), sender.subscribe()]; + + let mut tasks: FuturesUnordered<_> = rx + .into_iter() + .map(|rx| async move { collect_events(rx).await }) + .collect(); + + let handle = tokio::spawn(async move { + let mut results = vec![]; + while let Some(result) = tasks.next().await { + results.push(result) + } + results + }); + + tokio::spawn(async move { + for i in 1..=100 { + sender.send(&i); + } + }); + + let expected: Vec<i32> = (1..=100).into_iter().collect(); + + let results = handle.await.unwrap(); + assert_eq!(results.len(), 3); + + for res in results { + assert_eq!(res, expected); + } + } + + #[tokio::test] + async fn 
test_event_register_wakers() { + let sender = ClusterEventSender::new(100); + + let mut rx_1 = sender.subscribe(); + let mut rx_2 = sender.subscribe(); + let mut rx_3 = sender.subscribe(); + + sender.send(&0); + + // Subscribers haven't been polled yet so expect not registered wakers + assert_eq!(sender.registered_wakers(), 0); + + let event = rx_1.next().await; + assert_eq!(event, Some(0)); + assert_eq!(sender.registered_wakers(), 1); + + let event = rx_2.next().await; + assert_eq!(event, Some(0)); + assert_eq!(sender.registered_wakers(), 2); + + let event = rx_3.next().await; + assert_eq!(event, Some(0)); + assert_eq!(sender.registered_wakers(), 3); + + drop(rx_1); + assert_eq!(sender.registered_wakers(), 2); + + drop(rx_2); + assert_eq!(sender.registered_wakers(), 1); + + drop(rx_3); + assert_eq!(sender.registered_wakers(), 0); + } +} diff --git a/ballista/scheduler/src/cluster/kv.rs b/ballista/scheduler/src/cluster/kv.rs new file mode 100644 index 00000000..d130f339 --- /dev/null +++ b/ballista/scheduler/src/cluster/kv.rs @@ -0,0 +1,480 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::cluster::storage::{KeyValueStore, Keyspace, Lock, Operation, WatchEvent}; +use crate::cluster::{ + reserve_slots_bias, reserve_slots_round_robin, ClusterState, ExecutorHeartbeatStream, + JobState, JobStateEventStream, JobStatus, TaskDistribution, +}; +use crate::state::execution_graph::ExecutionGraph; +use crate::state::executor_manager::ExecutorReservation; +use crate::state::{decode_into, decode_protobuf}; +use async_trait::async_trait; +use ballista_core::config::BallistaConfig; +use ballista_core::error::{BallistaError, Result}; +use ballista_core::serde::protobuf::{ + self, AvailableTaskSlots, ExecutorHeartbeat, ExecutorTaskSlots, +}; +use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; +use ballista_core::serde::AsExecutionPlan; +use ballista_core::serde::BallistaCodec; +use datafusion::prelude::SessionContext; +use datafusion_proto::logical_plan::AsLogicalPlan; +use futures::StreamExt; +use itertools::Itertools; +use prost::Message; +use std::collections::{HashMap, HashSet}; +use std::future::Future; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; +use dashmap::DashMap; + +/// State implementation based on underlying `KeyValueStore` +pub struct KeyValueState< + S: KeyValueStore, + T: 'static + AsLogicalPlan, + U: 'static + AsExecutionPlan, +> { + /// Underlying `KeyValueStore` + store: S, + /// Codec used to serialize/deserialize execution plan + codec: BallistaCodec<T, U>, + /// Name of current scheduler. Should be `{host}:{port}` + scheduler: String, + /// In-memory store of queued jobs. 
Map from Job ID -> (Job Name, queued_at timestamp) + queued_jobs: DashMap<String, (String, u64)>, +} + +impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> + KeyValueState<S, T, U> +{ + pub fn new( + scheduler: impl Into<String>, + store: S, + codec: BallistaCodec<T, U>, + ) -> Self { + Self { + store, + scheduler: scheduler.into(), + codec, + queued_jobs: DashMap::new(), + } + } +} + +#[async_trait] +impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> + ClusterState for KeyValueState<S, T, U> +{ + async fn reserve_slots( + &self, + num_slots: u32, + distribution: TaskDistribution, + executors: Option<HashSet<String>>, + ) -> Result<Vec<ExecutorReservation>> { + let lock = self.store.lock(Keyspace::Slots, "global").await?; + + with_lock(lock, async { + let resources = self.store.get(Keyspace::Slots, "all").await?; + + let mut slots = + ExecutorTaskSlots::decode(resources.as_slice()).map_err(|err| { + BallistaError::Internal(format!( + "Unexpected value in executor slots state: {:?}", + err + )) + })?; + + let slots_iter = slots.task_slots.iter_mut().filter(|slots| { + executors + .as_ref() + .map(|executors| executors.contains(&slots.executor_id)) + .unwrap_or(true) + }); + + let reservations = match distribution { + TaskDistribution::Bias => reserve_slots_bias(slots_iter, num_slots), + TaskDistribution::RoundRobin => { + reserve_slots_round_robin(slots_iter, num_slots) + } + }; + + if !reservations.is_empty() { + self.store + .put(Keyspace::Slots, "all".to_owned(), slots.encode_to_vec()) + .await? 
+ } + + Ok(reservations) + }) + .await + } + + async fn reserve_slots_exact( + &self, + num_slots: u32, + distribution: TaskDistribution, + executors: Option<HashSet<String>>, + ) -> Result<Vec<ExecutorReservation>> { + let lock = self.store.lock(Keyspace::Slots, "global").await?; + + with_lock(lock, async { + let resources = self.store.get(Keyspace::Slots, "all").await?; + + let mut slots = + ExecutorTaskSlots::decode(resources.as_slice()).map_err(|err| { + BallistaError::Internal(format!( + "Unexpected value in executor slots state: {:?}", + err + )) + })?; + + let slots_iter = slots.task_slots.iter_mut().filter(|slots| { + executors + .as_ref() + .map(|executors| executors.contains(&slots.executor_id)) + .unwrap_or(true) + }); + + let reservations = match distribution { + TaskDistribution::Bias => reserve_slots_bias(slots_iter, num_slots), + TaskDistribution::RoundRobin => { + reserve_slots_round_robin(slots_iter, num_slots) + } + }; + + if reservations.len() == num_slots as usize { + self.store + .put(Keyspace::Slots, "all".to_owned(), slots.encode_to_vec()) + .await?; + Ok(reservations) + } else { + Ok(vec![]) + } + }) + .await + } + + async fn cancel_reservations( + &self, + reservations: Vec<ExecutorReservation>, + ) -> Result<()> { + let lock = self.store.lock(Keyspace::Slots, "global").await?; + + with_lock(lock, async { + let resources = self.store.get(Keyspace::Slots, "all").await?; + + let mut slots = + ExecutorTaskSlots::decode(resources.as_slice()).map_err(|err| { + BallistaError::Internal(format!( + "Unexpected value in executor slots state: {:?}", + err + )) + })?; + + let mut increments = HashMap::new(); + for ExecutorReservation { executor_id, .. 
} in reservations { + if let Some(inc) = increments.get_mut(&executor_id) { + *inc += 1; + } else { + increments.insert(executor_id, 1usize); + } + } + + for executor_slots in slots.task_slots.iter_mut() { + if let Some(slots) = increments.get(&executor_slots.executor_id) { + executor_slots.slots += *slots as u32; + } + } + + Ok(()) + }) + .await + } + + async fn register_executor( + &self, + metadata: ExecutorMetadata, + spec: ExecutorData, + reserve: bool, + ) -> Result<Vec<ExecutorReservation>> { + let executor_id = metadata.id.clone(); + + let current_ts = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| { + BallistaError::Internal(format!( + "Error getting current timestamp: {:?}", + e + )) + })? + .as_secs(); + + //TODO this should be in a transaction + // Now that we know we can connect, save the metadata and slots + self.save_executor_metadata(metadata).await?; + self.save_executor_heartbeat(ExecutorHeartbeat { + executor_id: executor_id.clone(), + timestamp: current_ts, + metrics: vec![], + status: Some(protobuf::ExecutorStatus { + status: Some(protobuf::executor_status::Status::Active("".to_string())), + }), + }) + .await?; + + if !reserve { + let proto: protobuf::ExecutorData = spec.into(); + self.store + .put(Keyspace::Slots, executor_id, proto.encode_to_vec()) + .await?; + Ok(vec![]) + } else { + let mut specification = spec; + let num_slots = specification.available_task_slots as usize; + let mut reservations: Vec<ExecutorReservation> = vec![]; + for _ in 0..num_slots { + reservations.push(ExecutorReservation::new_free(executor_id.clone())); + } + + specification.available_task_slots = 0; + + let proto: protobuf::ExecutorData = specification.into(); + self.store + .put(Keyspace::Slots, executor_id, proto.encode_to_vec()) + .await?; + Ok(reservations) + } + } + + async fn save_executor_metadata(&self, metadata: ExecutorMetadata) -> Result<()> { + let executor_id = metadata.id.clone(); + let proto: protobuf::ExecutorMetadata = 
metadata.into(); + + self.store + .put(Keyspace::Executors, executor_id, proto.encode_to_vec()) + .await + } + + async fn get_executor_metadata(&self, executor_id: &str) -> Result<ExecutorMetadata> { + let value = self.store.get(Keyspace::Executors, executor_id).await?; + + let decoded = + decode_into::<protobuf::ExecutorMetadata, ExecutorMetadata>(&value)?; + Ok(decoded) + } + + async fn save_executor_heartbeat(&self, heartbeat: ExecutorHeartbeat) -> Result<()> { + let executor_id = heartbeat.executor_id.clone(); + self.store + .put(Keyspace::Heartbeats, executor_id, heartbeat.encode_to_vec()) + .await + } + + async fn remove_executor(&self, executor_id: &str) -> Result<()> { + let current_ts = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_err(|e| { + BallistaError::Internal(format!( + "Error getting current timestamp: {:?}", + e + )) + })? + .as_secs(); + + let value = ExecutorHeartbeat { + executor_id: executor_id.to_owned(), + timestamp: current_ts, + metrics: vec![], + status: Some(protobuf::ExecutorStatus { + status: Some(protobuf::executor_status::Status::Dead("".to_string())), + }), + } + .encode_to_vec(); + + self.store + .put(Keyspace::Heartbeats, executor_id.to_owned(), value) + .await?; + + // TODO Check the Executor reservation logic for push-based scheduling + + Ok(()) + } + + async fn executor_heartbeat_stream(&self) -> Result<ExecutorHeartbeatStream> { + let events = self + .store + .watch(Keyspace::Heartbeats, String::default()) + .await?; + + Ok(events + .filter_map(|event| { + futures::future::ready(match event { + WatchEvent::Put(_, value) => { + if let Ok(heartbeat) = + decode_protobuf::<ExecutorHeartbeat>(&value) + { + Some(heartbeat) + } else { + None + } + } + WatchEvent::Delete(_) => None, + }) + }) + .boxed()) + } + + async fn executor_heartbeats(&self) -> Result<HashMap<String, ExecutorHeartbeat>> { + let heartbeats = self.store.scan(Keyspace::Heartbeats, None).await?; + + let mut heartbeat_map = 
HashMap::with_capacity(heartbeats.len()); + + for (_, value) in heartbeats { + let data: ExecutorHeartbeat = decode_protobuf(&value)?; + if let Some(protobuf::ExecutorStatus { + status: Some(protobuf::executor_status::Status::Active(_)), + }) = &data.status + { + heartbeat_map.insert(data.executor_id.clone(), data); + } + } + + Ok(heartbeat_map) + } +} + +#[async_trait] +impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> JobState + for KeyValueState<S, T, U> +{ + async fn accept_job(&self, job_id: &str, job_name: &str, queued_at: u64) -> Result<()> { + todo!() + } + + async fn submit_job(&self, job_id: String, graph: &ExecutionGraph) -> Result<()> { + let status = graph.status(); + let encoded_graph = + ExecutionGraph::encode_execution_graph(graph.clone(), &self.codec)?; + + self.store + .apply_txn(vec![ + ( + Operation::Put(status.encode_to_vec()), + Keyspace::JobStatus, + job_id.clone(), + ), + ( + Operation::Put(encoded_graph.encode_to_vec()), + Keyspace::ExecutionGraph, + job_id.clone(), + ), + ]) + .await?; + + Ok(()) + } + + async fn get_jobs(&self) -> Result<HashSet<String>> { + self.store.scan_keys(Keyspace::JobStatus).await + } + + async fn get_job_status(&self, job_id: &str) -> Result<Option<JobStatus>> { + let value = self.store.get(Keyspace::JobStatus, job_id).await?; + + (!value.is_empty()) + .then(|| decode_protobuf(value.as_slice())) + .transpose() + } + + async fn get_execution_graph(&self, job_id: &str) -> Result<Option<ExecutionGraph>> { + let value = self.store.get(Keyspace::ExecutionGraph, job_id).await?; + + if value.is_empty() { + return Ok(None); + } + + let proto: protobuf::ExecutionGraph = decode_protobuf(value.as_slice())?; + + let session = self.get_session(&proto.session_id).await?; + + Ok(Some( + ExecutionGraph::decode_execution_graph(proto, &self.codec, session.as_ref()) + .await?, + )) + } + + async fn save_job(&self, job_id: &str, graph: &ExecutionGraph) -> Result<()> { + let status = graph.status(); + 
let encoded_graph = + ExecutionGraph::encode_execution_graph(graph.clone(), &self.codec)?; + + self.store + .apply_txn(vec![ + ( + Operation::Put(status.encode_to_vec()), + Keyspace::JobStatus, + job_id.to_string(), + ), + ( + Operation::Put(encoded_graph.encode_to_vec()), + Keyspace::ExecutionGraph, + job_id.to_string(), + ), + ]) + .await + } + + async fn remove_job(&self, job_id: &str) -> Result<()> { + todo!() + } + + async fn try_acquire_job(&self, _job_id: &str) -> Result<Option<ExecutionGraph>> { + Err(BallistaError::NotImplemented( + "Work stealing is not currently implemented".to_string(), + )) + } + + async fn job_state_events(&self) -> JobStateEventStream { + todo!() + } + + async fn get_session(&self, session_id: &str) -> Result<Arc<SessionContext>> { + todo!() + } + + async fn create_session( + &self, + config: &BallistaConfig, + ) -> Result<Arc<SessionContext>> { + todo!() + } + + async fn update_session( + &self, + session_id: &str, + config: &BallistaConfig, + ) -> Result<Arc<SessionContext>> { + todo!() + } +} + +async fn with_lock<Out, F: Future<Output = Out>>(mut lock: Box<dyn Lock>, op: F) -> Out { + let result = op.await; + lock.unlock().await; + result +} diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs new file mode 100644 index 00000000..258444ec --- /dev/null +++ b/ballista/scheduler/src/cluster/memory.rs @@ -0,0 +1,353 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cluster::{ + reserve_slots_bias, reserve_slots_round_robin, ClusterState, ExecutorHeartbeatStream, + JobState, JobStateEvent, JobStateEventStream, JobStatus, TaskDistribution, +}; +use crate::state::execution_graph::ExecutionGraph; +use crate::state::executor_manager::ExecutorReservation; +use async_trait::async_trait; +use ballista_core::config::BallistaConfig; +use ballista_core::error::{BallistaError, Result}; +use ballista_core::serde::protobuf::{ExecutorHeartbeat, ExecutorTaskSlots}; +use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; +use dashmap::DashMap; +use datafusion::prelude::SessionContext; + +use crate::cluster::event::ClusterEventSender; +use crate::scheduler_server::SessionBuilder; +use crate::state::session_manager::{ + create_datafusion_context, update_datafusion_context, +}; +use ballista_core::serde::protobuf::job_status::Status; +use itertools::Itertools; +use log::warn; +use parking_lot::Mutex; +use std::collections::{HashMap, HashSet}; +use std::ops::{Deref, DerefMut}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +#[derive(Default)] +pub struct InMemoryClusterState { + /// Current available task slots for each executor + task_slots: Mutex<ExecutorTaskSlots>, + /// Current executors + executors: DashMap<String, (ExecutorMetadata, ExecutorData)>, + /// Last heartbeat received for each executor + heartbeats: DashMap<String, ExecutorHeartbeat>, + /// Broadcast channel sender for heartbeats, If `None` there are not + /// subscribers + heartbeat_sender: 
ClusterEventSender<ExecutorHeartbeat>, +} + +#[async_trait] +impl ClusterState for InMemoryClusterState { + async fn reserve_slots( + &self, + num_slots: u32, + distribution: TaskDistribution, + executors: Option<HashSet<String>>, + ) -> Result<Vec<ExecutorReservation>> { + let mut guard = self.task_slots.lock(); + + let slots_iter = guard.task_slots.iter_mut().filter(|slots| { + executors + .as_ref() + .map(|executors| executors.contains(&slots.executor_id)) + .unwrap_or(true) + }); + + let reservations = match distribution { + TaskDistribution::Bias => reserve_slots_bias(slots_iter, num_slots), + TaskDistribution::RoundRobin => { + reserve_slots_round_robin(slots_iter, num_slots) + } + }; + + Ok(reservations) + } + + async fn reserve_slots_exact( + &self, + num_slots: u32, + distribution: TaskDistribution, + executors: Option<HashSet<String>>, + ) -> Result<Vec<ExecutorReservation>> { + let mut guard = self.task_slots.lock(); + + let rollback = guard.clone(); + + let slots_iter = guard.task_slots.iter_mut().filter(|slots| { + executors + .as_ref() + .map(|executors| executors.contains(&slots.executor_id)) + .unwrap_or(true) + }); + + let reservations = match distribution { + TaskDistribution::Bias => reserve_slots_bias(slots_iter, num_slots), + TaskDistribution::RoundRobin => { + reserve_slots_round_robin(slots_iter, num_slots) + } + }; + + if reservations.len() as u32 != num_slots { + *guard = rollback; + Ok(vec![]) + } else { + Ok(reservations) + } + } + + async fn cancel_reservations( + &self, + reservations: Vec<ExecutorReservation>, + ) -> Result<()> { + let mut increments = HashMap::new(); + for ExecutorReservation { executor_id, .. 
} in reservations { + if let Some(inc) = increments.get_mut(&executor_id) { + *inc += 1; + } else { + increments.insert(executor_id, 1usize); + } + } + + let mut guard = self.task_slots.lock(); + + for executor_slots in guard.task_slots.iter_mut() { + if let Some(slots) = increments.get(&executor_slots.executor_id) { + executor_slots.slots += *slots as u32; + } + } + + Ok(()) + } + + async fn register_executor( + &self, + metadata: ExecutorMetadata, + mut spec: ExecutorData, + reserve: bool, + ) -> Result<Vec<ExecutorReservation>> { + if reserve { + let slots = std::mem::take(&mut spec.available_task_slots) as usize; + + let reservations = (0..slots) + .into_iter() + .map(|_| ExecutorReservation::new_free(metadata.id.clone())) + .collect(); + + self.executors.insert(metadata.id.clone(), (metadata, spec)); + + Ok(reservations) + } else { + self.executors.insert(metadata.id.clone(), (metadata, spec)); + + Ok(vec![]) + } + } + + async fn save_executor_metadata(&self, metadata: ExecutorMetadata) -> Result<()> { + todo!() + } + + async fn get_executor_metadata(&self, executor_id: &str) -> Result<ExecutorMetadata> { + todo!() + } + + async fn save_executor_heartbeat(&self, heartbeat: ExecutorHeartbeat) -> Result<()> { + if let Some(mut last) = self.heartbeats.get_mut(&heartbeat.executor_id) { + let _ = std::mem::replace(last.deref_mut(), heartbeat.clone()); + } else { + self.heartbeats + .insert(heartbeat.executor_id.clone(), heartbeat.clone()); + } + + self.heartbeat_sender.send(&heartbeat); + + Ok(()) + } + + async fn remove_executor(&self, executor_id: &str) -> Result<()> { + { + let mut guard = self.task_slots.lock(); + + if let Some((idx, _)) = guard + .task_slots + .iter() + .find_position(|slots| slots.executor_id == executor_id) + { + guard.task_slots.swap_remove(idx); + } + } + + self.executors.remove(executor_id); + self.heartbeats.remove(executor_id); + + Ok(()) + } + + async fn executor_heartbeat_stream(&self) -> Result<ExecutorHeartbeatStream> { + 
Ok(Box::pin(self.heartbeat_sender.subscribe())) + } + + async fn executor_heartbeats(&self) -> Result<HashMap<String, ExecutorHeartbeat>> { + Ok(self + .heartbeats + .iter() + .map(|r| (r.key().clone(), r.value().clone())) + .collect()) + } +} + +/// Implementation of `JobState` which keeps all state in memory. If using `InMemoryJobState` +/// no job state will be shared between schedulers +pub struct InMemoryJobState { + scheduler: String, + /// Jobs which have either completed successfully or failed + completed_jobs: DashMap<String, (JobStatus, Option<ExecutionGraph>)>, + /// In-memory store of queued jobs. Map from Job ID -> (Job Name, queued_at timestamp) + queued_jobs: DashMap<String, (String, u64)>, + /// Active ballista sessions + sessions: DashMap<String, Arc<SessionContext>>, + /// `SessionBuilder` for building DataFusion `SessionContext` from `BallistaConfig` + session_builder: SessionBuilder, + /// Sender of job events + job_event_sender: ClusterEventSender<JobStateEvent>, +} + +impl InMemoryJobState { + pub fn new(scheduler: impl Into<String>, session_builder: SessionBuilder) -> Self { + Self { + scheduler: scheduler.into(), + completed_jobs: Default::default(), + queued_jobs: Default::default(), + sessions: Default::default(), + session_builder, + job_event_sender: ClusterEventSender::new(100), + } + } +} + +#[async_trait] +impl JobState for InMemoryJobState { + + async fn submit_job(&self, job_id: String, _graph: &ExecutionGraph) -> Result<()> { + self.job_event_sender.send(&JobStateEvent::JobAcquired { + job_id, + owner: self.scheduler.clone(), + }); + + Ok(()) + } + + async fn get_job_status(&self, job_id: &str) -> Result<Option<JobStatus>> { + Ok(self + .completed_jobs + .get(job_id) + .map(|pair| pair.value().0.clone())) + } + + async fn get_execution_graph(&self, job_id: &str) -> Result<Option<ExecutionGraph>> { + Ok(self + .completed_jobs + .get(job_id) + .and_then(|pair| pair.value().1.clone())) + } + + async fn try_acquire_job(&self, _job_id: 
&str) -> Result<Option<ExecutionGraph>> { + // Always return None. The only state stored here are for completed jobs + // which cannot be acquired + Ok(None) + } + + async fn save_job(&self, job_id: &str, graph: &ExecutionGraph) -> Result<()> { + let status = graph.status(); + + // If job is either successful or failed, save to completed jobs + if matches!( + status.status, + Some(Status::Successful(_)) | Some(Status::Failed(_)) + ) { + self.completed_jobs + .insert(job_id.to_string(), (status, Some(graph.clone()))); + } + + Ok(()) + } + + async fn get_session(&self, session_id: &str) -> Result<Arc<SessionContext>> { + self.sessions + .get(session_id) + .map(|sess| sess.clone()) + .ok_or_else(|| { + BallistaError::General(format!("No session for {} found", session_id)) + }) + } + + async fn create_session( + &self, + config: &BallistaConfig, + ) -> Result<Arc<SessionContext>> { + let session = create_datafusion_context(config, self.session_builder); + self.sessions.insert(session.session_id(), session.clone()); + + Ok(session) + } + + async fn update_session( + &self, + session_id: &str, + config: &BallistaConfig, + ) -> Result<Arc<SessionContext>> { + if let Some(mut session) = self.sessions.get_mut(session_id) { + *session = update_datafusion_context(session.clone(), config); + Ok(session.clone()) + } else { + let session = create_datafusion_context(config, self.session_builder); + self.sessions + .insert(session_id.to_string(), session.clone()); + + Ok(session) + } + } + + async fn job_state_events(&self) -> JobStateEventStream { + Box::pin(self.job_event_sender.subscribe()) + } + + async fn remove_job(&self, job_id: &str) -> Result<()> { + if self.completed_jobs.remove(job_id).is_none() { + warn!("Tried to delete non-existent job {job_id} from state"); + } + Ok(()) + } + + async fn get_jobs(&self) -> Result<HashSet<String>> { + Ok(self + .completed_jobs + .iter() + .map(|pair| pair.key().clone()) + .collect()) + } + + async fn accept_job(&self, job_id: &str, 
job_name: &str, queued_at: u64) -> Result<()> { + todo!() + } +} diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs new file mode 100644 index 00000000..da7c16e8 --- /dev/null +++ b/ballista/scheduler/src/cluster/mod.rs @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +pub mod event; +pub mod kv; +pub mod memory; +pub mod storage; + +use crate::cluster::storage::KeyValueStore; +use crate::state::execution_graph::ExecutionGraph; +use crate::state::executor_manager::ExecutorReservation; +use ballista_core::config::BallistaConfig; +use ballista_core::error::Result; +use ballista_core::serde::protobuf::{AvailableTaskSlots, ExecutorHeartbeat, JobStatus}; +use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata}; +use ballista_core::serde::AsExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_proto::logical_plan::AsLogicalPlan; +use futures::Stream; +use std::collections::{HashMap, HashSet}; +use std::pin::Pin; +use std::sync::Arc; + +pub struct BallistaCluster { + cluster_state: Arc<dyn ClusterState>, + job_state: Arc<dyn JobState>, +} + +impl BallistaCluster { + pub fn new( + cluster_state: Arc<dyn ClusterState>, + job_state: Arc<dyn JobState>, + ) -> Self { + Self { + cluster_state, + job_state, + } + } + + pub fn cluster_state(&self) -> Arc<dyn ClusterState> { + self.cluster_state.clone() + } + + pub fn job_state(&self) -> Arc<dyn JobState> { + self.job_state.clone() + } +} + +/// Stream of `ExecutorHeartbeat`. This stream should contain all `ExecutorHeartbeats` received +/// by any schedulers with a shared `ClusterState` +pub type ExecutorHeartbeatStream = Pin<Box<dyn Stream<Item = ExecutorHeartbeat> + Send>>; + +/// Method of distributing tasks to available executor slots +#[derive(Debug, Clone, Copy)] +pub enum TaskDistribution { + /// Eagerly assign tasks to executor slots. This will assign as many task slots per executor + /// as are currently available + Bias, + /// Distribute tasks evenly across executors. This will try and iterate through available executors + /// and assign one task to each executor until all tasks are assigned. 
+ RoundRobin, +} + +/// A trait that contains the necessary methods to maintain a globally consistent view of cluster resources +#[tonic::async_trait] +pub trait ClusterState: Send + Sync { + /// Reserve up to `num_slots` executor task slots. If not enough task slots are available, reserve + /// as many as possible. + /// + /// If `executors` is provided, only reserve slots of the specified executor IDs + async fn reserve_slots( + &self, + num_slots: u32, + distribution: TaskDistribution, + executors: Option<HashSet<String>>, + ) -> Result<Vec<ExecutorReservation>>; + + /// Reserve exactly `num_slots` executor task slots. If not enough task slots are available, + /// returns an empty vec + /// + /// If `executors` is provided, only reserve slots of the specified executor IDs + async fn reserve_slots_exact( + &self, + num_slots: u32, + distribution: TaskDistribution, + executors: Option<HashSet<String>>, + ) -> Result<Vec<ExecutorReservation>>; + + /// Cancel the specified reservations. This will make reserved executor slots available to other + /// tasks. + /// This operation should be atomic. Either all reservations are cancelled or none are + async fn cancel_reservations( + &self, + reservations: Vec<ExecutorReservation>, + ) -> Result<()>; + + /// Register a new executor in the cluster. If `reserve` is true, then the executors task slots + /// will be reserved and returned in the response and none of the new executors task slots will be + /// available to other tasks. + async fn register_executor( + &self, + metadata: ExecutorMetadata, + spec: ExecutorData, + reserve: bool, + ) -> Result<Vec<ExecutorReservation>>; + + /// Save the executor metadata. This will overwrite existing metadata for the executor ID + async fn save_executor_metadata(&self, metadata: ExecutorMetadata) -> Result<()>; + + /// Get executor metadata for the provided executor ID. 
Returns an error if the executor does not exist + async fn get_executor_metadata(&self, executor_id: &str) -> Result<ExecutorMetadata>; + + /// Save the executor heartbeat + async fn save_executor_heartbeat(&self, heartbeat: ExecutorHeartbeat) -> Result<()>; + + /// Remove the executor from the cluster + async fn remove_executor(&self, executor_id: &str) -> Result<()>; + + /// Return the stream of executor heartbeats observed by all schedulers in the cluster. + /// This can be aggregated to provide an eventually consistent view of all executors within the cluster + async fn executor_heartbeat_stream(&self) -> Result<ExecutorHeartbeatStream>; + + /// Return a map of the last seen heartbeat for all active executors + async fn executor_heartbeats(&self) -> Result<HashMap<String, ExecutorHeartbeat>>; +} + +#[derive(Debug, Clone)] +pub enum JobStateEvent { + /// Event when a job status has been updated + JobUpdated { + /// Job ID of updated job + job_id: String, + /// New job status + status: JobStatus, + }, + /// Event when a scheduler acquires ownership of the job. This happens + /// either when a scheduler submits a job (in which case ownership is implied) + /// or when a scheduler acquires ownership of a running job release by a + /// different scheduler + JobAcquired { + /// Job ID of the acquired job + job_id: String, + /// The scheduler which acquired ownership of the job + owner: String, + }, + /// Event when a scheduler releases ownership of a still active job + JobReleased { + /// Job ID of the released job + job_id: String, + }, + /// Event when a new session has been created + SessionCreated { + session_id: String, + config: BallistaConfig, + }, + /// Event when a session configuration has been updated + SessionUpdated { + session_id: String, + config: BallistaConfig, + }, +} + +/// Stream of `JobStateEvent`. 
This stream should contain all `JobStateEvent`s received +/// by any schedulers with a shared `ClusterState` +pub type JobStateEventStream = Pin<Box<dyn Stream<Item = JobStateEvent> + Send>>; + +/// A trait that contains the necessary methods for persisting state related to executing jobs +#[tonic::async_trait] +pub trait JobState: Send + Sync { + /// Accept job into a scheduler's job queue. This should be called when a job is + /// received by the scheduler but before it is planned and may or may not be saved + /// in global state + async fn accept_job(&self, job_id: &str, job_name: &str, queued_at: u64) -> Result<()>; + + /// Submit a new job to the `JobState`. It is assumed that the submitter owns the job. + /// In local state the job should be saved as `JobStatus::Active` and in shared state + /// it should be saved as `JobStatus::Running` with `scheduler` set to the current scheduler + async fn submit_job(&self, job_id: String, graph: &ExecutionGraph) -> Result<()>; + + /// Return a `HashSet` of all active job IDs in the `JobState` + async fn get_jobs(&self) -> Result<HashSet<String>>; + + /// Fetch the job status + async fn get_job_status(&self, job_id: &str) -> Result<Option<JobStatus>>; + + /// Get the `ExecutionGraph` for a job. The job may or may not belong to the caller + /// and should return the `ExecutionGraph` for the given job (if it exists) at the + /// time this method is called with no guarantees that the graph has not been + /// subsequently updated by another scheduler. + async fn get_execution_graph(&self, job_id: &str) -> Result<Option<ExecutionGraph>>; + + /// Persist the current state of an owned job to global state. This should fail + /// if the job is not owned by the caller. + async fn save_job(&self, job_id: &str, graph: &ExecutionGraph) -> Result<()>; + + /// Delete a job from the global state + async fn remove_job(&self, job_id: &str) -> Result<()>; + + /// Attempt to acquire ownership of the given job. 
If the job is still in a running state + /// and is successfully acquired by the caller, return the current `ExecutionGraph`, + /// otherwise return `None` + async fn try_acquire_job(&self, job_id: &str) -> Result<Option<ExecutionGraph>>; + + /// Get a stream of all `JobState` events. An event should be published any time that status + /// of a job changes in state + async fn job_state_events(&self) -> JobStateEventStream; + + /// Get the `SessionContext` associated with `session_id`. Returns an error if the + /// session does not exist + async fn get_session(&self, session_id: &str) -> Result<Arc<SessionContext>>; + + /// Create a new saved session + async fn create_session( + &self, + config: &BallistaConfig, + ) -> Result<Arc<SessionContext>>; + + // Update a new saved session. If the session does not exist, a new one will be created + async fn update_session( + &self, + session_id: &str, + config: &BallistaConfig, + ) -> Result<Arc<SessionContext>>; +} + +pub(crate) fn reserve_slots_bias<'a, I: Iterator<Item = &'a mut AvailableTaskSlots>>( + mut slots: I, + mut n: u32, +) -> Vec<ExecutorReservation> { + let mut reservations = Vec::with_capacity(n as usize); + + while n > 0 { + if let Some(executor) = slots.next() { + let take = executor.slots.min(n); + for _ in 0..take { + reservations + .push(ExecutorReservation::new_free(executor.executor_id.clone())); + } + + executor.slots -= take; + n -= take; + } else { + break; + } + } + + reservations +} + +pub(crate) fn reserve_slots_round_robin< + 'a, + I: Iterator<Item = &'a mut AvailableTaskSlots>, +>( + mut slots: I, + mut n: u32, +) -> Vec<ExecutorReservation> { + let mut reservations = Vec::with_capacity(n as usize); + + while n > 0 { + if let Some(executor) = slots.next() { + if executor.slots > 0 { + reservations + .push(ExecutorReservation::new_free(executor.executor_id.clone())); + executor.slots -= 1; + n -= 1; + } + } else { + break; + } + } + + reservations +} diff --git 
a/ballista/scheduler/src/cluster/storage/etcd.rs b/ballista/scheduler/src/cluster/storage/etcd.rs new file mode 100644 index 00000000..62fb4292 --- /dev/null +++ b/ballista/scheduler/src/cluster/storage/etcd.rs @@ -0,0 +1,346 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashSet; + +use std::task::Poll; + +use async_trait::async_trait; +use ballista_core::error::{ballista_error, Result}; +use std::time::Instant; + +use crate::cluster::storage::KeyValueStore; +use etcd_client::{ + GetOptions, LockOptions, LockResponse, Txn, TxnOp, WatchOptions, WatchStream, Watcher, +}; +use futures::{Stream, StreamExt}; +use log::{debug, error, warn}; + +use crate::cluster::storage::{Keyspace, Lock, Operation, Watch, WatchEvent}; + +/// A [`StateBackendClient`] implementation that uses etcd to save cluster state. 
+#[derive(Clone)] +pub struct EtcdClient { + namespace: String, + etcd: etcd_client::Client, +} + +impl EtcdClient { + pub fn new(namespace: String, etcd: etcd_client::Client) -> Self { + Self { namespace, etcd } + } +} + +#[async_trait] +impl KeyValueStore for EtcdClient { + async fn get(&self, keyspace: Keyspace, key: &str) -> Result<Vec<u8>> { + let key = format!("/{}/{:?}/{}", self.namespace, keyspace, key); + + Ok(self + .etcd + .clone() + .get(key, None) + .await + .map_err(|e| ballista_error(&format!("etcd error {:?}", e)))? + .kvs() + .get(0) + .map(|kv| kv.value().to_owned()) + .unwrap_or_default()) + } + + async fn get_from_prefix( + &self, + keyspace: Keyspace, + prefix: &str, + ) -> Result<Vec<(String, Vec<u8>)>> { + let prefix = format!("/{}/{:?}/{}", self.namespace, keyspace, prefix); + + Ok(self + .etcd + .clone() + .get(prefix, Some(GetOptions::new().with_prefix())) + .await + .map_err(|e| ballista_error(&format!("etcd error {:?}", e)))? + .kvs() + .iter() + .map(|kv| (kv.key_str().unwrap().to_owned(), kv.value().to_owned())) + .collect()) + } + + async fn scan( + &self, + keyspace: Keyspace, + limit: Option<usize>, + ) -> Result<Vec<(String, Vec<u8>)>> { + let prefix = format!("/{}/{:?}/", self.namespace, keyspace); + + let options = if let Some(limit) = limit { + GetOptions::new().with_prefix().with_limit(limit as i64) + } else { + GetOptions::new().with_prefix() + }; + + Ok(self + .etcd + .clone() + .get(prefix, Some(options)) + .await + .map_err(|e| ballista_error(&format!("etcd error {:?}", e)))? 
+ .kvs() + .iter() + .map(|kv| (kv.key_str().unwrap().to_owned(), kv.value().to_owned())) + .collect()) + } + + async fn scan_keys(&self, keyspace: Keyspace) -> Result<HashSet<String>> { + let prefix = format!("/{}/{:?}/", self.namespace, keyspace); + + let options = GetOptions::new().with_prefix().with_keys_only(); + + Ok(self + .etcd + .clone() + .get(prefix.clone(), Some(options)) + .await + .map_err(|e| ballista_error(&format!("etcd error {:?}", e)))? + .kvs() + .iter() + .map(|kv| { + kv.key_str() + .unwrap() + .strip_prefix(&prefix) + .unwrap() + .to_owned() + }) + .collect()) + } + + async fn put(&self, keyspace: Keyspace, key: String, value: Vec<u8>) -> Result<()> { + let key = format!("/{}/{:?}/{}", self.namespace, keyspace, key); + + let mut etcd = self.etcd.clone(); + etcd.put(key, value.clone(), None) + .await + .map_err(|e| { + warn!("etcd put failed: {}", e); + ballista_error(&format!("etcd put failed: {}", e)) + }) + .map(|_| ()) + } + + /// Apply multiple operations in a single transaction. 
+ async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) -> Result<()> { + let mut etcd = self.etcd.clone(); + + let txn_ops: Vec<TxnOp> = ops + .into_iter() + .map(|(operation, ks, key)| { + let key = format!("/{}/{:?}/{}", self.namespace, ks, key); + match operation { + Operation::Put(value) => TxnOp::put(key, value, None), + Operation::Delete => TxnOp::delete(key, None), + } + }) + .collect(); + + etcd.txn(Txn::new().and_then(txn_ops)) + .await + .map_err(|e| { + error!("etcd operation failed: {}", e); + ballista_error(&format!("etcd operation failed: {}", e)) + }) + .map(|_| ()) + } + + async fn mv( + &self, + from_keyspace: Keyspace, + to_keyspace: Keyspace, + key: &str, + ) -> Result<()> { + let mut etcd = self.etcd.clone(); + let from_key = format!("/{}/{:?}/{}", self.namespace, from_keyspace, key); + let to_key = format!("/{}/{:?}/{}", self.namespace, to_keyspace, key); + + let current_value = etcd + .get(from_key.as_str(), None) + .await + .map_err(|e| ballista_error(&format!("etcd error {:?}", e)))? + .kvs() + .get(0) + .map(|kv| kv.value().to_owned()); + + if let Some(value) = current_value { + let txn = Txn::new().and_then(vec![ + TxnOp::delete(from_key.as_str(), None), + TxnOp::put(to_key.as_str(), value, None), + ]); + etcd.txn(txn).await.map_err(|e| { + error!("etcd put failed: {}", e); + ballista_error("etcd move failed") + })?; + } else { + warn!("Cannot move value at {}, does not exist", from_key); + } + + Ok(()) + } + + async fn lock(&self, keyspace: Keyspace, key: &str) -> Result<Box<dyn Lock>> { + let start = Instant::now(); + let mut etcd = self.etcd.clone(); + + let lock_id = format!("/{}/mutex/{:?}/{}", self.namespace, keyspace, key); + + // Create a lease which expires after 30 seconds. We then associate this lease with the lock + // acquired below. This protects against a scheduler dying unexpectedly while holding locks + // on shared resources. In that case, those locks would expire once the lease expires. 
+ // TODO This is not great to do for every lock. We should have a single lease per scheduler instance + let lease_id = etcd + .lease_client() + .grant(30, None) + .await + .map_err(|e| { + warn!("etcd lease failed: {}", e); + ballista_error("etcd lease failed") + })? + .id(); + + let lock_options = LockOptions::new().with_lease(lease_id); + + let lock = etcd + .lock(lock_id.as_str(), Some(lock_options)) + .await + .map_err(|e| { + warn!("etcd lock failed: {}", e); + ballista_error("etcd lock failed") + })?; + + let elapsed = start.elapsed(); + debug!("Acquired lock {} in {:?}", lock_id, elapsed); + Ok(Box::new(EtcdLockGuard { etcd, lock })) + } + + async fn watch(&self, keyspace: Keyspace, prefix: String) -> Result<Box<dyn Watch>> { + let prefix = format!("/{}/{:?}/{}", self.namespace, keyspace, prefix); + + let mut etcd = self.etcd.clone(); + let options = WatchOptions::new().with_prefix(); + let (watcher, stream) = etcd.watch(prefix, Some(options)).await.map_err(|e| { + warn!("etcd watch failed: {}", e); + ballista_error("etcd watch failed") + })?; + Ok(Box::new(EtcdWatch { + watcher, + stream, + buffered_events: Vec::new(), + })) + } + + async fn delete(&self, keyspace: Keyspace, key: &str) -> Result<()> { + let key = format!("/{}/{:?}/{}", self.namespace, keyspace, key); + + let mut etcd = self.etcd.clone(); + + etcd.delete(key, None).await.map_err(|e| { + warn!("etcd delete failed: {:?}", e); + ballista_error("etcd delete failed") + })?; + + Ok(()) + } +} + +struct EtcdWatch { + watcher: Watcher, + stream: WatchStream, + buffered_events: Vec<WatchEvent>, +} + +#[tonic::async_trait] +impl Watch for EtcdWatch { + async fn cancel(&mut self) -> Result<()> { + self.watcher.cancel().await.map_err(|e| { + warn!("etcd watch cancel failed: {}", e); + ballista_error("etcd watch cancel failed") + }) + } +} + +impl Stream for EtcdWatch { + type Item = WatchEvent; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> 
Poll<Option<Self::Item>> { + let self_mut = self.get_mut(); + if let Some(event) = self_mut.buffered_events.pop() { + Poll::Ready(Some(event)) + } else { + loop { + match self_mut.stream.poll_next_unpin(cx) { + Poll::Ready(Some(Err(e))) => { + warn!("Error when watching etcd prefix: {}", e); + continue; + } + Poll::Ready(Some(Ok(v))) => { + self_mut.buffered_events.extend(v.events().iter().map(|ev| { + match ev.event_type() { + etcd_client::EventType::Put => { + let kv = ev.kv().unwrap(); + WatchEvent::Put( + kv.key_str().unwrap().to_string(), + kv.value().to_owned(), + ) + } + etcd_client::EventType::Delete => { + let kv = ev.kv().unwrap(); + WatchEvent::Delete(kv.key_str().unwrap().to_string()) + } + } + })); + if let Some(event) = self_mut.buffered_events.pop() { + return Poll::Ready(Some(event)); + } else { + continue; + } + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + } + } + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.stream.size_hint() + } +} + +struct EtcdLockGuard { + etcd: etcd_client::Client, + lock: LockResponse, +} + +// Cannot use Drop because we need this to be async +#[tonic::async_trait] +impl Lock for EtcdLockGuard { + async fn unlock(&mut self) { + self.etcd.unlock(self.lock.key()).await.unwrap(); + } +} diff --git a/ballista/scheduler/src/cluster/storage/mod.rs b/ballista/scheduler/src/cluster/storage/mod.rs new file mode 100644 index 00000000..0cee7a1e --- /dev/null +++ b/ballista/scheduler/src/cluster/storage/mod.rs @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#[cfg(feature = "etcd")] +pub mod etcd; +#[cfg(feature = "sled")] +pub mod sled; + +use async_trait::async_trait; +use ballista_core::error::Result; +use futures::{future, Stream}; +use std::collections::HashSet; +use tokio::sync::OwnedMutexGuard; + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub enum Keyspace { + Executors, + JobStatus, + ExecutionGraph, + ActiveJobs, + CompletedJobs, + FailedJobs, + Slots, + Sessions, + Heartbeats, +} + +#[derive(Debug, Eq, PartialEq, Hash)] +pub enum Operation { + Put(Vec<u8>), + Delete, +} + +/// A trait that defines a KeyValue interface with basic locking primitives for persisting Ballista cluster state +#[async_trait] +pub trait KeyValueStore: Send + Sync + Clone { + /// Retrieve the data associated with a specific key in a given keyspace. + /// + /// An empty vec is returned if the key does not exist. + async fn get(&self, keyspace: Keyspace, key: &str) -> Result<Vec<u8>>; + + /// Retrieve all key/value pairs in given keyspace matching a given key prefix. + async fn get_from_prefix( + &self, + keyspace: Keyspace, + prefix: &str, + ) -> Result<Vec<(String, Vec<u8>)>>; + + /// Retrieve all key/value pairs in a given keyspace. If a limit is specified, will return at + /// most `limit` key-value pairs. + async fn scan( + &self, + keyspace: Keyspace, + limit: Option<usize>, + ) -> Result<Vec<(String, Vec<u8>)>>; + + /// Retrieve all keys from a given keyspace (without their values). The implementations + /// should handle stripping any prefixes it may add. 
+ async fn scan_keys(&self, keyspace: Keyspace) -> Result<HashSet<String>>; + + /// Saves the value into the provided key, overriding any previous data that might have been associated to that key. + async fn put(&self, keyspace: Keyspace, key: String, value: Vec<u8>) -> Result<()>; + + /// Bundle multiple operation in a single transaction. Either all values should be saved, or all should fail. + /// It can support multiple types of operations and keyspaces. If the count of the unique keyspace is more than one, + /// more than one locks has to be acquired. + async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) -> Result<()>; + /// Acquire mutex with specified IDs. + async fn acquire_locks( + &self, + mut ids: Vec<(Keyspace, &str)>, + ) -> Result<Vec<Box<dyn Lock>>> { + // We always acquire locks in a specific order to avoid deadlocks. + ids.sort_by_key(|n| format!("/{:?}/{}", n.0, n.1)); + future::try_join_all(ids.into_iter().map(|(ks, key)| self.lock(ks, key))).await + } + + /// Atomically move the given key from one keyspace to another + async fn mv( + &self, + from_keyspace: Keyspace, + to_keyspace: Keyspace, + key: &str, + ) -> Result<()>; + + /// Acquire mutex with specified ID. + async fn lock(&self, keyspace: Keyspace, key: &str) -> Result<Box<dyn Lock>>; + + /// Watch all events that happen on a specific prefix. 
+ async fn watch( + &self, + keyspace: Keyspace, + prefix: String, + ) -> Result<Box<dyn Watch<Item = WatchEvent>>>; + + /// Permanently delete a key from state + async fn delete(&self, keyspace: Keyspace, key: &str) -> Result<()>; +} + +/// A Watch is a cancelable stream of put or delete events in the [StateBackendClient] +#[async_trait] +pub trait Watch: Stream<Item = WatchEvent> + Send + Unpin { + async fn cancel(&mut self) -> Result<()>; +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum WatchEvent { + /// Contains the inserted or updated key and the new value + Put(String, Vec<u8>), + + /// Contains the deleted key + Delete(String), +} + +#[async_trait] +pub trait Lock: Send + Sync { + async fn unlock(&mut self); +} + +#[async_trait] +impl<T: Send + Sync> Lock for OwnedMutexGuard<T> { + async fn unlock(&mut self) {} +} diff --git a/ballista/scheduler/src/cluster/storage/sled.rs b/ballista/scheduler/src/cluster/storage/sled.rs new file mode 100644 index 00000000..67700d2c --- /dev/null +++ b/ballista/scheduler/src/cluster/storage/sled.rs @@ -0,0 +1,401 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::collections::{HashMap, HashSet}; +use std::{sync::Arc, task::Poll}; + +use ballista_core::error::{ballista_error, BallistaError, Result}; + +use crate::cluster::storage::KeyValueStore; +use async_trait::async_trait; +use futures::{FutureExt, Stream}; +use log::warn; +use sled_package as sled; +use tokio::sync::Mutex; + +use crate::cluster::storage::{Keyspace, Lock, Operation, Watch, WatchEvent}; + +/// A [`StateBackendClient`] implementation that uses file-based storage to save cluster state. +#[derive(Clone)] +pub struct SledClient { + db: sled::Db, + locks: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>, +} + +impl SledClient { + /// Creates a SledClient that saves data to the specified file. + pub fn try_new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> { + Ok(Self { + db: sled::open(path).map_err(sled_to_ballista_error)?, + locks: Arc::new(Mutex::new(HashMap::new())), + }) + } + + /// Creates a SledClient that saves data to a temp file. + pub fn try_new_temporary() -> Result<Self> { + Ok(Self { + db: sled::Config::new() + .temporary(true) + .open() + .map_err(sled_to_ballista_error)?, + locks: Arc::new(Mutex::new(HashMap::new())), + }) + } +} + +fn sled_to_ballista_error(e: sled::Error) -> BallistaError { + match e { + sled::Error::Io(io) => BallistaError::IoError(io), + _ => BallistaError::General(format!("{}", e)), + } +} + +#[async_trait] +impl KeyValueStore for SledClient { + async fn get(&self, keyspace: Keyspace, key: &str) -> Result<Vec<u8>> { + let key = format!("/{:?}/{}", keyspace, key); + Ok(self + .db + .get(key) + .map_err(|e| ballista_error(&format!("sled error {:?}", e)))? 
+ .map(|v| v.to_vec()) + .unwrap_or_default()) + } + + async fn get_from_prefix( + &self, + keyspace: Keyspace, + prefix: &str, + ) -> Result<Vec<(String, Vec<u8>)>> { + let prefix = format!("/{:?}/{}", keyspace, prefix); + Ok(self + .db + .scan_prefix(prefix) + .map(|v| { + v.map(|(key, value)| { + ( + std::str::from_utf8(&key).unwrap().to_owned(), + value.to_vec(), + ) + }) + }) + .collect::<std::result::Result<Vec<_>, _>>() + .map_err(|e| ballista_error(&format!("sled error {:?}", e)))?) + } + + async fn scan( + &self, + keyspace: Keyspace, + limit: Option<usize>, + ) -> Result<Vec<(String, Vec<u8>)>> { + let prefix = format!("/{:?}/", keyspace); + if let Some(limit) = limit { + Ok(self + .db + .scan_prefix(prefix) + .take(limit) + .map(|v| { + v.map(|(key, value)| { + ( + std::str::from_utf8(&key).unwrap().to_owned(), + value.to_vec(), + ) + }) + }) + .collect::<std::result::Result<Vec<_>, _>>() + .map_err(|e| ballista_error(&format!("sled error {:?}", e)))?) + } else { + Ok(self + .db + .scan_prefix(prefix) + .map(|v| { + v.map(|(key, value)| { + ( + std::str::from_utf8(&key).unwrap().to_owned(), + value.to_vec(), + ) + }) + }) + .collect::<std::result::Result<Vec<_>, _>>() + .map_err(|e| ballista_error(&format!("sled error {:?}", e)))?) + } + } + + async fn scan_keys(&self, keyspace: Keyspace) -> Result<HashSet<String>> { + let prefix = format!("/{:?}/", keyspace); + Ok(self + .db + .scan_prefix(prefix.clone()) + .map(|v| { + v.map(|(key, _value)| { + std::str::from_utf8(&key) + .unwrap() + .strip_prefix(&prefix) + .unwrap() + .to_owned() + }) + }) + .collect::<std::result::Result<HashSet<_>, _>>() + .map_err(|e| ballista_error(&format!("sled error {:?}", e)))?) 
+ } + + async fn put(&self, keyspace: Keyspace, key: String, value: Vec<u8>) -> Result<()> { + let key = format!("/{:?}/{}", keyspace, key); + self.db + .insert(key, value) + .map_err(|e| { + warn!("sled insert failed: {}", e); + ballista_error("sled insert failed") + }) + .map(|_| ()) + } + + async fn apply_txn(&self, ops: Vec<(Operation, Keyspace, String)>) -> Result<()> { + let mut batch = sled::Batch::default(); + + for (op, keyspace, key_str) in ops { + let key = format!("/{:?}/{}", &keyspace, key_str); + match op { + Operation::Put(value) => batch.insert(key.as_str(), value), + Operation::Delete => batch.remove(key.as_str()), + } + } + + self.db.apply_batch(batch).map_err(|e| { + warn!("sled transaction insert failed: {}", e); + ballista_error("sled operations failed") + }) + } + + async fn mv( + &self, + from_keyspace: Keyspace, + to_keyspace: Keyspace, + key: &str, + ) -> Result<()> { + let from_key = format!("/{:?}/{}", from_keyspace, key); + let to_key = format!("/{:?}/{}", to_keyspace, key); + + let current_value = self + .db + .get(from_key.as_str()) + .map_err(|e| ballista_error(&format!("sled error {:?}", e)))? + .map(|v| v.to_vec()); + + if let Some(value) = current_value { + let mut batch = sled::Batch::default(); + + batch.remove(from_key.as_str()); + batch.insert(to_key.as_str(), value); + + self.db.apply_batch(batch).map_err(|e| { + warn!("sled transaction insert failed: {}", e); + ballista_error("sled insert failed") + }) + } else { + // TODO should this return an error? 
+ warn!("Cannot move value at {}, does not exist", from_key); + Ok(()) + } + } + + async fn lock(&self, keyspace: Keyspace, key: &str) -> Result<Box<dyn Lock>> { + let mut mlock = self.locks.lock().await; + let lock_key = format!("/{:?}/{}", keyspace, key); + if let Some(lock) = mlock.get(&lock_key) { + Ok(Box::new(lock.clone().lock_owned().await)) + } else { + let new_lock = Arc::new(Mutex::new(())); + mlock.insert(lock_key, new_lock.clone()); + Ok(Box::new(new_lock.lock_owned().await)) + } + } + + async fn watch( + &self, + keyspace: Keyspace, + prefix: String, + ) -> Result<Box<dyn Watch<Item = WatchEvent>>> { + let prefix = format!("/{:?}/{}", keyspace, prefix); + + Ok(Box::new(SledWatch { + subscriber: self.db.watch_prefix(prefix), + })) + } + + async fn delete(&self, keyspace: Keyspace, key: &str) -> Result<()> { + let key = format!("/{:?}/{}", keyspace, key); + self.db.remove(key).map_err(|e| { + warn!("sled delete failed: {:?}", e); + ballista_error("sled delete failed") + })?; + Ok(()) + } +} + +struct SledWatch { + subscriber: sled::Subscriber, +} + +#[tonic::async_trait] +impl Watch for SledWatch { + async fn cancel(&mut self) -> Result<()> { + Ok(()) + } +} + +impl Stream for SledWatch { + type Item = WatchEvent; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll<Option<Self::Item>> { + match self.get_mut().subscriber.poll_unpin(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(sled::Event::Insert { key, value })) => { + let key = std::str::from_utf8(&key).unwrap().to_owned(); + Poll::Ready(Some(WatchEvent::Put(key, value.to_vec()))) + } + Poll::Ready(Some(sled::Event::Remove { key })) => { + let key = std::str::from_utf8(&key).unwrap().to_owned(); + Poll::Ready(Some(WatchEvent::Delete(key))) + } + } + } + + fn size_hint(&self) -> (usize, Option<usize>) { + self.subscriber.size_hint() + } +} + +#[cfg(test)] +mod tests { + use super::{KeyValueStore, 
SledClient, Watch, WatchEvent}; + + use crate::cluster::storage::{Keyspace, Operation}; + use crate::state::with_locks; + use futures::StreamExt; + use std::result::Result; + + fn create_instance() -> Result<SledClient, Box<dyn std::error::Error>> { + Ok(SledClient::try_new_temporary()?) + } + + #[tokio::test] + async fn put_read() -> Result<(), Box<dyn std::error::Error>> { + let client = create_instance()?; + let key = "key"; + let value = "value".as_bytes(); + client + .put(Keyspace::Slots, key.to_owned(), value.to_vec()) + .await?; + assert_eq!(client.get(Keyspace::Slots, key).await?, value); + Ok(()) + } + + // #[tokio::test] + // async fn multiple_operation() -> Result<(), Box<dyn std::error::Error>> { + // let client = create_instance()?; + // let key = "key".to_string(); + // let value = "value".as_bytes().to_vec(); + // let locks = client + // .acquire_locks(vec![(Keyspace::ActiveJobs, ""), (Keyspace::Slots, "")]) + // .await?; + // + // let _r: ballista_core::error::Result<()> = with_locks(locks, async { + // let txn_ops = vec![ + // (Operation::Put(value.clone()), Keyspace::Slots, key.clone()), + // ( + // Operation::Put(value.clone()), + // Keyspace::ActiveJobs, + // key.clone(), + // ), + // ]; + // client.apply_txn(txn_ops).await?; + // Ok(()) + // }) + // .await; + // + // assert_eq!(client.get(Keyspace::Slots, key.as_str()).await?, value); + // assert_eq!(client.get(Keyspace::ActiveJobs, key.as_str()).await?, value); + // Ok(()) + // } + + #[tokio::test] + async fn read_empty() -> Result<(), Box<dyn std::error::Error>> { + let client = create_instance()?; + let key = "key"; + let empty: &[u8] = &[]; + assert_eq!(client.get(Keyspace::Slots, key).await?, empty); + Ok(()) + } + + #[tokio::test] + async fn read_prefix() -> Result<(), Box<dyn std::error::Error>> { + let client = create_instance()?; + let key = "key"; + let value = "value".as_bytes(); + client + .put(Keyspace::Slots, format!("{}/1", key), value.to_vec()) + .await?; + client + 
.put(Keyspace::Slots, format!("{}/2", key), value.to_vec()) + .await?; + assert_eq!( + client.get_from_prefix(Keyspace::Slots, key).await?, + vec![ + ("/Slots/key/1".to_owned(), value.to_vec()), + ("/Slots/key/2".to_owned(), value.to_vec()) + ] + ); + Ok(()) + } + + #[tokio::test] + async fn read_watch() -> Result<(), Box<dyn std::error::Error>> { + let client = create_instance()?; + let key = "key"; + let value = "value".as_bytes(); + let mut watch: Box<dyn Watch<Item = WatchEvent>> = + client.watch(Keyspace::Slots, key.to_owned()).await?; + client + .put(Keyspace::Slots, key.to_owned(), value.to_vec()) + .await?; + assert_eq!( + watch.next().await, + Some(WatchEvent::Put( + format!("/{:?}/{}", Keyspace::Slots, key.to_owned()), + value.to_owned() + )) + ); + let value2 = "value2".as_bytes(); + client + .put(Keyspace::Slots, key.to_owned(), value2.to_vec()) + .await?; + assert_eq!( + watch.next().await, + Some(WatchEvent::Put( + format!("/{:?}/{}", Keyspace::Slots, key.to_owned()), + value2.to_owned() + )) + ); + watch.cancel().await?; + Ok(()) + } +} diff --git a/ballista/scheduler/src/lib.rs b/ballista/scheduler/src/lib.rs index da21296f..cd6f047b 100644 --- a/ballista/scheduler/src/lib.rs +++ b/ballista/scheduler/src/lib.rs @@ -18,6 +18,7 @@ #![doc = include_str ! 
("../README.md")] pub mod api; +pub mod cluster; pub mod config; pub mod display; pub mod metrics; diff --git a/ballista/scheduler/src/scheduler_process.rs b/ballista/scheduler/src/scheduler_process.rs index 44cf7bf4..440baca2 100644 --- a/ballista/scheduler/src/scheduler_process.rs +++ b/ballista/scheduler/src/scheduler_process.rs @@ -36,6 +36,7 @@ use ballista_core::utils::create_grpc_server; use ballista_core::BALLISTA_VERSION; use crate::api::{get_routes, EitherBody, Error}; +use crate::cluster::BallistaCluster; use crate::config::SchedulerConfig; use crate::flight_sql::FlightSqlServiceImpl; use crate::metrics::default_metrics_collector; @@ -46,8 +47,8 @@ use crate::state::backend::StateBackendClient; pub async fn start_server( scheduler_name: String, + cluster: BallistaCluster, config_backend: Arc<dyn StateBackendClient>, - cluster_state: Arc<dyn ClusterState>, addr: SocketAddr, config: SchedulerConfig, ) -> Result<()> { @@ -67,7 +68,7 @@ pub async fn start_server( SchedulerServer::new( scheduler_name, config_backend.clone(), - cluster_state, + cluster, BallistaCodec::default(), config, metrics_collector, diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs index f39e0d1b..1feb01b1 100644 --- a/ballista/scheduler/src/scheduler_server/mod.rs +++ b/ballista/scheduler/src/scheduler_server/mod.rs @@ -29,6 +29,7 @@ use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::{SessionConfig, SessionContext}; use datafusion_proto::logical_plan::AsLogicalPlan; +use crate::cluster::BallistaCluster; use crate::config::SchedulerConfig; use crate::metrics::SchedulerMetricsCollector; use log::{error, warn}; @@ -70,14 +71,14 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T pub fn new( scheduler_name: String, config_backend: Arc<dyn StateBackendClient>, - cluster_state: Arc<dyn ClusterState>, + cluster: BallistaCluster, codec: BallistaCodec<T, U>, config: SchedulerConfig, 
metrics_collector: Arc<dyn SchedulerMetricsCollector>, ) -> Self { let state = Arc::new(SchedulerState::new( config_backend, - cluster_state, + cluster, default_session_builder, codec, scheduler_name.clone(), @@ -103,7 +104,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T pub fn with_session_builder( scheduler_name: String, config_backend: Arc<dyn StateBackendClient>, - cluster_backend: Arc<dyn ClusterState>, + cluster: BallistaCluster, + codec: BallistaCodec<T, U>, config: SchedulerConfig, session_builder: SessionBuilder, @@ -111,7 +113,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T ) -> Self { let state = Arc::new(SchedulerState::new( config_backend, - cluster_backend, + cluster, session_builder, codec, scheduler_name.clone(), @@ -138,7 +140,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T pub(crate) fn with_task_launcher( scheduler_name: String, config_backend: Arc<dyn StateBackendClient>, - cluster_backend: Arc<dyn ClusterState>, + cluster: BallistaCluster, codec: BallistaCodec<T, U>, config: SchedulerConfig, metrics_collector: Arc<dyn SchedulerMetricsCollector>, @@ -146,7 +148,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerServer<T ) -> Self { let state = Arc::new(SchedulerState::with_task_launcher( config_backend, - cluster_backend, + cluster, default_session_builder, codec, scheduler_name.clone(), diff --git a/ballista/scheduler/src/standalone.rs b/ballista/scheduler/src/standalone.rs index 0fe608f8..c8405347 100644 --- a/ballista/scheduler/src/standalone.rs +++ b/ballista/scheduler/src/standalone.rs @@ -15,13 +15,16 @@ // specific language governing permissions and limitations // under the License. 
+use crate::cluster::event::ClusterEventSender; +use crate::cluster::memory::{InMemoryClusterState, InMemoryJobState}; +use crate::cluster::BallistaCluster; use crate::config::SchedulerConfig; use crate::metrics::default_metrics_collector; use crate::state::backend::cluster::DefaultClusterState; use crate::{scheduler_server::SchedulerServer, state::backend::sled::SledClient}; use ballista_core::serde::protobuf::PhysicalPlanNode; use ballista_core::serde::BallistaCodec; -use ballista_core::utils::create_grpc_server; +use ballista_core::utils::{create_grpc_server, default_session_builder}; use ballista_core::{ error::Result, serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer, BALLISTA_VERSION, @@ -36,11 +39,19 @@ pub async fn new_standalone_scheduler() -> Result<SocketAddr> { let metrics_collector = default_metrics_collector()?; + let cluster = BallistaCluster::new( + Arc::new(InMemoryClusterState::default()), + Arc::new(InMemoryJobState::new( + "localhost:50050", + default_session_builder, + )), + ); + let mut scheduler_server: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> = SchedulerServer::new( "localhost:50050".to_owned(), backend.clone(), - Arc::new(DefaultClusterState::new(backend)), + cluster, BallistaCodec::default(), SchedulerConfig::default(), metrics_collector, diff --git a/ballista/scheduler/src/state/execution_graph.rs b/ballista/scheduler/src/state/execution_graph.rs index 976ec5f5..a86f82ad 100644 --- a/ballista/scheduler/src/state/execution_graph.rs +++ b/ballista/scheduler/src/state/execution_graph.rs @@ -33,9 +33,10 @@ use log::{error, info, warn}; use ballista_core::error::{BallistaError, Result}; use ballista_core::execution_plans::{ShuffleWriterExec, UnresolvedShuffleExec}; use ballista_core::serde::protobuf::failed_task::FailedReason; +use ballista_core::serde::protobuf::job_status::Status; use ballista_core::serde::protobuf::{ self, execution_graph_stage::StageType, FailedTask, JobStatus, QueuedJob, ResultLost, - SuccessfulJob, 
TaskStatus, + RunningJob, SuccessfulJob, TaskStatus, }; use ballista_core::serde::protobuf::{job_status, FailedJob, ShuffleWritePartition}; use ballista_core::serde::protobuf::{task_status, RunningTask}; @@ -101,8 +102,8 @@ mod execution_stage; /// publish its outputs to the `ExecutionGraph`s `output_locations` representing the final query results. #[derive(Clone)] pub struct ExecutionGraph { - /// Curator scheduler name - scheduler_id: String, + /// Curator scheduler name. Can be `None` is `ExecutionGraph` is not currently curated by any scheduler + scheduler_id: Option<String>, /// ID for this job job_id: String, /// Job name, can be empty string @@ -158,12 +159,14 @@ impl ExecutionGraph { let stages = builder.build(shuffle_stages)?; Ok(Self { - scheduler_id: scheduler_id.to_string(), + scheduler_id: Some(scheduler_id.to_string()), job_id: job_id.to_string(), job_name: job_name.to_string(), session_id: session_id.to_string(), status: JobStatus { - status: Some(job_status::Status::Queued(QueuedJob {})), + job_id: job_id.to_string(), + job_name: job_name.to_string(), + status: Some(job_status::Status::Queued(QueuedJob { queued_at })), }, queued_at, start_time: timestamp_millis(), @@ -221,6 +224,12 @@ impl ExecutionGraph { .all(|s| matches!(s, ExecutionStage::Successful(_))) } + pub fn is_complete(&self) -> bool { + self.stages + .values() + .all(|s| matches!(s, ExecutionStage::Successful(_))) + } + /// Revive the execution graph by converting the resolved stages to running stages /// If any stages are converted, return true; else false. pub fn revive(&mut self) -> bool { @@ -838,6 +847,7 @@ impl ExecutionGraph { self.status, JobStatus { status: Some(job_status::Status::Failed(_)), + .. 
} ) { warn!("Call pop_next_task on failed Job"); @@ -1211,7 +1221,14 @@ impl ExecutionGraph { /// fail job with error message pub fn fail_job(&mut self, error: String) { self.status = JobStatus { - status: Some(job_status::Status::Failed(FailedJob { error })), + job_id: self.job_id.clone(), + job_name: self.job_name.clone(), + status: Some(job_status::Status::Failed(FailedJob { + error, + queued_at: self.queued_at, + started_at: self.start_time, + ended_at: self.end_time, + })), }; } @@ -1231,8 +1248,14 @@ impl ExecutionGraph { .collect::<Result<Vec<_>>>()?; self.status = JobStatus { + job_id: self.job_id.clone(), + job_name: self.job_name.clone(), status: Some(job_status::Status::Successful(SuccessfulJob { partition_location, + + queued_at: self.queued_at, + started_at: self.start_time, + ended_at: self.end_time, })), }; self.end_time = SystemTime::now() @@ -1248,6 +1271,18 @@ impl ExecutionGraph { self.failed_stage_attempts.remove(&stage_id); } + pub(crate) fn disown(&mut self) { + self.scheduler_id = None; + + if let JobStatus { + status: Some(Status::Running(RunningJob { scheduler, .. })), + .. 
+ } = &mut self.status + { + std::mem::take(scheduler); + } + } + pub(crate) async fn decode_execution_graph< T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan, @@ -1309,7 +1344,7 @@ impl ExecutionGraph { .collect(); Ok(ExecutionGraph { - scheduler_id: proto.scheduler_id, + scheduler_id: (!proto.scheduler_id.is_empty()).then_some(proto.scheduler_id), job_id: proto.job_id, job_name: proto.job_name, session_id: proto.session_id, @@ -1399,7 +1434,7 @@ impl ExecutionGraph { stages, output_partitions: graph.output_partitions as u64, output_locations, - scheduler_id: graph.scheduler_id, + scheduler_id: graph.scheduler_id.unwrap_or_default(), task_id_gen: graph.task_id_gen as u32, failed_attempts, }) diff --git a/ballista/scheduler/src/state/execution_graph_dot.rs b/ballista/scheduler/src/state/execution_graph_dot.rs index 6e65612f..521271e0 100644 --- a/ballista/scheduler/src/state/execution_graph_dot.rs +++ b/ballista/scheduler/src/state/execution_graph_dot.rs @@ -45,20 +45,20 @@ use std::fmt::{self, Write}; use std::sync::Arc; /// Utility for producing dot diagrams from execution graphs -pub struct ExecutionGraphDot { - graph: Arc<ExecutionGraph>, +pub struct ExecutionGraphDot<'a> { + graph: &'a ExecutionGraph, } -impl ExecutionGraphDot { +impl<'a> ExecutionGraphDot<'a> { /// Create a DOT graph from the provided ExecutionGraph - pub fn generate(graph: Arc<ExecutionGraph>) -> Result<String, fmt::Error> { + pub fn generate(graph: &'a ExecutionGraph) -> Result<String, fmt::Error> { let mut dot = Self { graph }; dot._generate() } /// Create a DOT graph for one query stage from the provided ExecutionGraph pub fn generate_for_query_stage( - graph: Arc<ExecutionGraph>, + graph: &ExecutionGraph, stage_id: usize, ) -> Result<String, fmt::Error> { if let Some(stage) = graph.stages().get(&stage_id) { @@ -503,7 +503,7 @@ filter_expr="] #[tokio::test] async fn query_stage() -> Result<()> { let graph = test_graph().await?; - let dot = 
ExecutionGraphDot::generate_for_query_stage(Arc::new(graph), 3) + let dot = ExecutionGraphDot::generate_for_query_stage(&graph, 3) .map_err(|e| BallistaError::Internal(format!("{:?}", e)))?; let expected = r#"digraph G { @@ -595,7 +595,7 @@ filter_expr="] #[tokio::test] async fn query_stage_optimized() -> Result<()> { let graph = test_graph_optimized().await?; - let dot = ExecutionGraphDot::generate_for_query_stage(Arc::new(graph), 4) + let dot = ExecutionGraphDot::generate_for_query_stage(&graph, 4) .map_err(|e| BallistaError::Internal(format!("{:?}", e)))?; let expected = r#"digraph G { diff --git a/ballista/scheduler/src/state/executor_manager.rs b/ballista/scheduler/src/state/executor_manager.rs index a6234528..6bd18336 100644 --- a/ballista/scheduler/src/state/executor_manager.rs +++ b/ballista/scheduler/src/state/executor_manager.rs @@ -17,13 +17,13 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use crate::state::backend::TaskDistribution; +use crate::cluster::TaskDistribution; use ballista_core::error::{BallistaError, Result}; use ballista_core::serde::protobuf; +use crate::cluster::ClusterState; use crate::config::SlotsPolicy; -use crate::state::backend::cluster::ClusterState; use crate::state::execution_graph::RunningTaskInfo; use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient; use ballista_core::serde::protobuf::{ diff --git a/ballista/scheduler/src/state/mod.rs b/ballista/scheduler/src/state/mod.rs index 99791457..da0a36cb 100644 --- a/ballista/scheduler/src/state/mod.rs +++ b/ballista/scheduler/src/state/mod.rs @@ -31,6 +31,7 @@ use crate::state::executor_manager::{ExecutorManager, ExecutorReservation}; use crate::state::session_manager::SessionManager; use crate::state::task_manager::{TaskLauncher, TaskManager}; +use crate::cluster::{BallistaCluster, JobState}; use crate::config::SchedulerConfig; use crate::state::execution_graph::TaskDescription; use backend::cluster::ClusterState; @@ -100,7 +101,7 @@ impl<T: 'static 
+ AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T, #[cfg(test)] pub fn new_with_default_scheduler_name( config_client: Arc<dyn StateBackendClient>, - cluster_state: Arc<dyn ClusterState>, + cluster: BallistaCluster, session_builder: SessionBuilder, codec: BallistaCodec<T, U>, ) -> Self { @@ -116,7 +117,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T, pub fn new( config_client: Arc<dyn StateBackendClient>, - cluster_state: Arc<dyn ClusterState>, + cluster: BallistaCluster, session_builder: SessionBuilder, codec: BallistaCodec<T, U>, scheduler_name: String, @@ -124,11 +125,12 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T, ) -> Self { Self { executor_manager: ExecutorManager::new( - cluster_state, + cluster.cluster_state(), config.executor_slots_policy, ), task_manager: TaskManager::new( config_client.clone(), + cluster.job_state(), session_builder, codec.clone(), scheduler_name, @@ -142,7 +144,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T, #[allow(dead_code)] pub(crate) fn with_task_launcher( config_client: Arc<dyn StateBackendClient>, - cluster_state: Arc<dyn ClusterState>, + cluster: BallistaCluster, session_builder: SessionBuilder, codec: BallistaCodec<T, U>, scheduler_name: String, @@ -151,11 +153,12 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T, ) -> Self { Self { executor_manager: ExecutorManager::new( - cluster_state, + cluster.cluster_state(), config.executor_slots_policy, ), task_manager: TaskManager::with_launcher( config_client.clone(), + cluster.job_state(), session_builder, codec.clone(), scheduler_name, @@ -414,7 +417,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T, job_id.clone(), self.config.finished_job_data_clean_up_interval_seconds, ); - self.task_manager.delete_successful_job_delayed( + self.task_manager.clean_up_job_delayed( job_id, 
self.config.finished_job_state_clean_up_interval_seconds, ); @@ -423,7 +426,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerState<T, /// Spawn a delayed future to clean up job data on both Scheduler and Executors pub(crate) fn clean_up_failed_job(&self, job_id: String) { self.executor_manager.clean_up_job_data(job_id.clone()); - self.task_manager.clean_up_failed_job_delayed( + self.task_manager.clean_up_job_delayed( job_id, self.config.finished_job_state_clean_up_interval_seconds, ); diff --git a/ballista/scheduler/src/state/task_manager.rs b/ballista/scheduler/src/state/task_manager.rs index 53b59219..25dcdaea 100644 --- a/ballista/scheduler/src/state/task_manager.rs +++ b/ballista/scheduler/src/state/task_manager.rs @@ -30,6 +30,7 @@ use ballista_core::error::Result; use crate::state::backend::Keyspace::{CompletedJobs, FailedJobs}; use crate::state::session_manager::create_datafusion_context; +use crate::cluster::JobState; use ballista_core::serde::protobuf::{ self, job_status, FailedJob, JobStatus, MultiTaskDefinition, TaskDefinition, TaskId, TaskStatus, @@ -45,10 +46,12 @@ use log::{debug, error, info, warn}; use rand::distributions::Alphanumeric; use rand::{thread_rng, Rng}; use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use std::sync::Arc; use std::time::Duration; use std::time::{SystemTime, UNIX_EPOCH}; use tokio::sync::RwLock; +use tokio::sync::RwLockReadGuard; use tracing::trace; type ActiveJobCache = Arc<DashMap<String, JobInfoCache>>; @@ -108,6 +111,7 @@ impl TaskLauncher for DefaultTaskLauncher { #[derive(Clone)] pub struct TaskManager<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> { state: Arc<dyn StateBackendClient>, + job_state: Arc<dyn JobState>, session_builder: SessionBuilder, codec: BallistaCodec<T, U>, scheduler_id: String, @@ -145,12 +149,14 @@ pub struct UpdatedStages { impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> { pub fn new( state: Arc<dyn 
StateBackendClient>, + job_state: Arc<dyn JobState>, session_builder: SessionBuilder, codec: BallistaCodec<T, U>, scheduler_id: String, ) -> Self { Self { state, + job_state, session_builder, codec, scheduler_id: scheduler_id.clone(), @@ -162,6 +168,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> #[allow(dead_code)] pub(crate) fn with_launcher( state: Arc<dyn StateBackendClient>, + job_state: Arc<dyn JobState>, session_builder: SessionBuilder, codec: BallistaCodec<T, U>, scheduler_id: String, @@ -169,6 +176,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> ) -> Self { Self { state, + job_state, session_builder, codec, scheduler_id, @@ -197,12 +205,9 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> queued_at, )?; info!("Submitting execution graph: {:?}", graph); - self.state - .put( - Keyspace::ActiveJobs, - job_id.to_owned(), - self.encode_execution_graph(graph.clone())?, - ) + + self.job_state + .submit_job(job_id.to_string(), &graph) .await?; graph.revive(); @@ -214,36 +219,20 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> /// Get a list of active job ids pub async fn get_jobs(&self) -> Result<Vec<JobOverview>> { - let mut job_ids = vec![]; - for job_id in self.state.scan_keys(Keyspace::ActiveJobs).await? { - job_ids.push(job_id); - } - for job_id in self.state.scan_keys(Keyspace::CompletedJobs).await? { - job_ids.push(job_id); - } - for job_id in self.state.scan_keys(Keyspace::FailedJobs).await? 
{ - job_ids.push(job_id); - } + let job_ids = self.job_state.get_jobs().await?; let mut jobs = vec![]; for job_id in &job_ids { - let graph = self.get_execution_graph(job_id).await?; - - let mut completed_stages = 0; - for stage in graph.stages().values() { - if let ExecutionStage::Successful(_) = stage { - completed_stages += 1; - } + if let Some(cached) = self.active_job_cache.get(job_id) { + let graph = cached.execution_graph.read().await; + jobs.push(graph.deref().into()); + } else { + let graph = self.job_state + .get_execution_graph(job_id) + .await? + .ok_or_else(|| BallistaError::Internal(format!("Error getting job overview, no execution graph found for job {job_id}")))?; + jobs.push((&graph).into()); } - jobs.push(JobOverview { - job_id: job_id.clone(), - job_name: graph.job_name().to_string(), - status: graph.status(), - start_time: graph.start_time(), - end_time: graph.end_time(), - num_stages: graph.stage_count(), - completed_stages, - }); } Ok(jobs) } @@ -251,36 +240,23 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> /// Get the status of of a job. First look in the active cache. /// If no one found, then in the Active/Completed jobs, and then in Failed jobs pub async fn get_job_status(&self, job_id: &str) -> Result<Option<JobStatus>> { - if let Some(graph) = self.get_active_execution_graph(job_id).await { - let status = graph.read().await.status(); - Ok(Some(status)) - } else if let Ok(graph) = self.get_execution_graph(job_id).await { - Ok(Some(graph.status())) - } else { - let value = self.state.get(Keyspace::FailedJobs, job_id).await?; - - if !value.is_empty() { - let status = decode_protobuf(&value)?; - Ok(Some(status)) - } else { - Ok(None) - } - } + self.job_state.get_job_status(job_id).await } /// Get the execution graph of of a job. First look in the active cache. /// If no one found, then in the Active/Completed jobs. 
- pub async fn get_job_execution_graph( + pub(crate) async fn get_job_execution_graph( &self, job_id: &str, ) -> Result<Option<Arc<ExecutionGraph>>> { - if let Some(graph) = self.get_active_execution_graph(job_id).await { - Ok(Some(Arc::new(graph.read().await.clone()))) - } else if let Ok(graph) = self.get_execution_graph(job_id).await { - Ok(Some(Arc::new(graph))) + if let Some(cached) = self.active_job_cache.get(job_id) { + let guard = cached.execution_graph.read().await; + + Ok(Some(Arc::new(guard.deref().clone()))) } else { - // if the job failed then we return no graph for now - Ok(None) + let graph = self.job_state.get_execution_graph(job_id).await?; + + Ok(graph.map(Arc::new)) } } @@ -305,9 +281,9 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> let num_tasks = statuses.len(); debug!("Updating {} tasks in job {}", num_tasks, job_id); - let graph = self.get_active_execution_graph(&job_id).await; - let job_events = if let Some(graph) = graph { - let mut graph = graph.write().await; + // let graph = self.get_active_execution_graph(&job_id).await; + let job_events = if let Some(cached) = self.active_job_cache.get(&job_id) { + let mut graph = cached.execution_graph.write().await; graph.update_task_status( executor, statuses, @@ -387,16 +363,11 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> /// and remove the job from ActiveJobs pub(crate) async fn succeed_job(&self, job_id: &str) -> Result<()> { debug!("Moving job {} from Active to Success", job_id); - let lock = self.state.lock(Keyspace::ActiveJobs, "").await?; - with_lock(lock, self.state.delete(Keyspace::ActiveJobs, job_id)).await?; if let Some(graph) = self.remove_active_execution_graph(job_id).await { let graph = graph.read().await.clone(); if graph.is_successful() { - let value = self.encode_execution_graph(graph)?; - self.state - .put(Keyspace::CompletedJobs, job_id.to_owned(), value) - .await?; + self.job_state.save_job(job_id, 
&graph).await?; } else { error!("Job {} has not finished and cannot be completed", job_id); return Ok(()); @@ -503,15 +474,20 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> warn!("Rollback Execution Graph state change since it did not persisted due to a possible connection error.") }; } else { - info!("Fail to find job {} in the cache", job_id); - let status = JobStatus { - status: Some(job_status::Status::Failed(FailedJob { - error: failure_reason.clone(), - })), - }; - let value = encode_protobuf(&status)?; - let txn_ops = txn_operations(value); - self.state.apply_txn(txn_ops).await?; + warn!( + "Fail to find job {} in the cache, not updating status to failed", + job_id + ); + // let status = JobStatus { + // job_id: job_id.to_string(), + // job_name: "".to_string(), + // status: Some(job_status::Status::Failed(FailedJob { + // error: failure_reason.clone(), + // })), + // }; + // let value = encode_protobuf(&status)?; + // let txn_ops = txn_operations(value); + // self.state.apply_txn(txn_ops).await?; }; Ok(()) @@ -527,12 +503,9 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> graph.revive(); let graph = graph.clone(); - let new_tasks = graph.available_tasks() - curr_available_tasks; + self.job_state.save_job(job_id, &graph).await?; - let value = self.encode_execution_graph(graph)?; - self.state - .put(Keyspace::ActiveJobs, job_id.to_owned(), value) - .await?; + let new_tasks = graph.available_tasks() - curr_available_tasks; Ok(new_tasks) } else { @@ -560,20 +533,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> } } - let lock = self.state.lock(Keyspace::ActiveJobs, "").await?; - with_lock(lock, async { - // Transactional update graphs - let txn_ops: Vec<(Operation, Keyspace, String)> = updated_graphs - .into_iter() - .map(|(job_id, graph)| { - let value = self.encode_execution_graph(graph)?; - Ok((Operation::Put(value), Keyspace::ActiveJobs, job_id)) - }) 
- .collect::<Result<Vec<_>>>()?; - self.state.apply_txn(txn_ops).await?; - Ok(running_tasks_to_cancel) - }) - .await + Ok(running_tasks_to_cancel) } /// Retrieve the number of available tasks for the given job. The value returned @@ -763,46 +723,25 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> &self, job_id: &str, ) -> Result<ExecutionGraph> { - let value = self.state.get(Keyspace::ActiveJobs, job_id).await?; + if let Some(active) = self.active_job_cache.get(job_id) { + let guard = active.execution_graph.read().await; - if value.is_empty() { - let value = self.state.get(Keyspace::CompletedJobs, job_id).await?; - self.decode_execution_graph(value).await + Ok(guard.clone()) } else { - self.decode_execution_graph(value).await + let graph = self + .job_state + .get_execution_graph(job_id) + .await? + .ok_or_else(|| { + BallistaError::Internal(format!( + "No ExecutionGraph found for job {job_id}" + )) + })?; + + Ok(graph) } } - async fn get_session(&self, session_id: &str) -> Result<Arc<SessionContext>> { - let value = self.state.get(Keyspace::Sessions, session_id).await?; - - let settings: protobuf::SessionSettings = decode_protobuf(&value)?; - - let mut config_builder = BallistaConfig::builder(); - for kv_pair in &settings.configs { - config_builder = config_builder.set(&kv_pair.key, &kv_pair.value); - } - let config = config_builder.build()?; - - Ok(create_datafusion_context(&config, self.session_builder)) - } - - async fn decode_execution_graph(&self, value: Vec<u8>) -> Result<ExecutionGraph> { - let proto: protobuf::ExecutionGraph = decode_protobuf(&value)?; - - let session_id = &proto.session_id; - - let session_ctx = self.get_session(session_id).await?; - - ExecutionGraph::decode_execution_graph(proto, &self.codec, &session_ctx).await - } - - fn encode_execution_graph(&self, graph: ExecutionGraph) -> Result<Vec<u8>> { - let proto = ExecutionGraph::encode_execution_graph(graph, &self.codec)?; - - encode_protobuf(&proto) - } - 
/// Generate a new random Job ID pub fn generate_job_id(&self) -> String { let mut rng = thread_rng(); @@ -814,55 +753,21 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> } /// Clean up a failed job in FailedJobs Keyspace by delayed clean_up_interval seconds - pub(crate) fn clean_up_failed_job_delayed( - &self, - job_id: String, - clean_up_interval: u64, - ) { + pub(crate) fn clean_up_job_delayed(&self, job_id: String, clean_up_interval: u64) { if clean_up_interval == 0 { info!("The interval is 0 and the clean up for the failed job state {} will not triggered", job_id); return; } - self.delete_from_state_backend_delayed(FailedJobs, job_id, clean_up_interval) - } - /// Clean up a successful job in CompletedJobs Keyspace by delayed clean_up_interval seconds - pub(crate) fn delete_successful_job_delayed( - &self, - job_id: String, - clean_up_interval: u64, - ) { - if clean_up_interval == 0 { - info!("The interval is 0 and the clean up for the successful job state {} will not triggered", job_id); - return; - } - self.delete_from_state_backend_delayed(CompletedJobs, job_id, clean_up_interval) - } - - /// Clean up entries in some keyspace by delayed clean_up_interval seconds - fn delete_from_state_backend_delayed( - &self, - keyspace: Keyspace, - key: String, - clean_up_interval: u64, - ) { - let state = self.state.clone(); + let state = self.job_state.clone(); tokio::spawn(async move { + let job_id = job_id; tokio::time::sleep(Duration::from_secs(clean_up_interval)).await; - Self::delete_from_state_backend(state, keyspace, &key).await + if let Err(err) = state.remove_job(&job_id).await { + error!("Failed to delete job {job_id}: {err:?}"); + } }); } - - async fn delete_from_state_backend( - state: Arc<dyn StateBackendClient>, - keyspace: Keyspace, - key: &str, - ) -> Result<()> { - let lock = state.lock(keyspace.clone(), "").await?; - with_lock(lock, state.delete(keyspace, key)).await?; - - Ok(()) - } } pub struct JobOverview { @@ 
-874,3 +779,24 @@ pub struct JobOverview { pub num_stages: usize, pub completed_stages: usize, } + +impl From<&ExecutionGraph> for JobOverview { + fn from(value: &ExecutionGraph) -> Self { + let mut completed_stages = 0; + for stage in value.stages().values() { + if let ExecutionStage::Successful(_) = stage { + completed_stages += 1; + } + } + + Self { + job_id: value.job_id().to_string(), + job_name: value.job_name().to_string(), + status: value.status(), + start_time: value.start_time(), + end_time: value.end_time(), + num_stages: value.stage_count(), + completed_stages, + } + } +}
